# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Owner(s): ["module: tests"]

import torch
import torch.utils.data
import numpy as np

import contextlib
import gc
import io
import inspect
import itertools
import math
import random
import re
import copy
import os
import tempfile
import unittest
import warnings
import types
import pickle
import textwrap
import subprocess
import weakref
import sys
import copyreg
from torch import inf, nan
from itertools import product, combinations, permutations, chain
from functools import partial
from torch import multiprocessing as mp
from torch.testing import make_tensor
from torch.testing._internal.common_optimizers import (
    optim_db, optims, _get_optim_inputs_including_global_cliquey_kwargs)

from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON,
    IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
    TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype,
    skipCUDAMemoryLeakCheckIf, BytesIOContext,
    skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
    wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
    bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like,
    AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo)
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import (
    expectedFailureMeta,
    expectedFailureXLA,
    instantiate_device_type_tests,
    onlyCUDA, onlyCPU,
    dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
    skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes,
    get_all_device_types, skipXLA)
from typing import Tuple
import torch.backends.quantized
import torch.testing._internal.data
from torch.testing._internal.common_cuda import (
    tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN, TEST_MULTIGPU,
    _create_scaling_case, _create_scaling_models_optimizers)
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.testing._internal.common_dtype import (
    floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
    all_types_and, floating_types, floating_and_complex_types, integral_types_and,
    get_all_qint_dtypes,
)
from torch.testing._internal.two_tensor import TwoTensor

if TEST_WITH_TORCHINDUCTOR:
    from torch._inductor.test_case import TestCase
else:
    from torch.testing._internal.common_utils import TestCase  # type: ignore[assignment]


# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

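# AMPERE_OR_ROCM is True on ROCm builds or on GPUs where TF32 may be used in
# place of full-precision fp32 matmuls (Ampere or newer); presumably used below
# to relax exactness expectations on such configurations.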
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()

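# torch_vital_set temporarily overrides the TORCH_VITAL environment variable,
# which controls the "vital signs" logging exercised by the tests below, and
# restores the previous value (or unsets the variable) on exit, e.g.:
#
#     with torch_vital_set('ON'):
#         assert torch.vitals_enabled()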
@contextlib.contextmanager
def torch_vital_set(value):
    stash = None
    if 'TORCH_VITAL' in os.environ:
        stash = os.environ['TORCH_VITAL']
    os.environ['TORCH_VITAL'] = value
    try:
        yield
    finally:
        if stash:
            os.environ['TORCH_VITAL'] = stash
        else:
            del os.environ['TORCH_VITAL']

# Tests Vital Signs for Torch
# FIXME: document or deprecate whatever this is
class TestBasicVitalSigns(TestCase):
    def test_basic_vitals(self):
        with torch_vital_set(''):
            self.assertFalse(torch.vitals_enabled())
        with torch_vital_set('ON'):
            self.assertTrue(torch.vitals_enabled())

    def test_basic_vitals_read_write(self):
        with torch_vital_set('ON'):
            self.assertTrue(torch.vitals_enabled())
            # This tests the code path of setting a vital
            self.assertTrue(torch.set_vital('Dataloader', 'basic_unit_test', 'TEST_VALUE_STRING'))
            self.assertIn('TEST_VALUE_STRING', torch.read_vitals())
            self.assertIn('CUDA.used', torch.read_vitals())

    def test_dataloader_vitals(self):
        with torch_vital_set('ON'):
            inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
            tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
            dataset = torch.utils.data.TensorDataset(inps, tgts)
            loader = torch.utils.data.DataLoader(dataset, batch_size=2)
            self.assertIn('Dataloader.enabled\t\t True', torch.read_vitals())

# FIXME: document or deprecate whatever this is
class TestVitalSignsCuda(TestCase):
    @onlyCUDA
    def test_cuda_vitals_gpu_only(self, device):
        with torch_vital_set('ON'):
            self.assertIn('CUDA.used\t\t true', torch.read_vitals())


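# True when running on a GPU with CUDA compute capability 8.6 (sm_86); lets
# tests special-case that hardware where needed.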
is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 6)

class TestTorchDeviceType(TestCase):
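    # exact_dtype=True asks the internal TestCase.assertEqual to also check
    # dtypes, not just values.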
    exact_dtype = True

    # TODO: move all tensor creation to common ops
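    # Returns a random shape: a tuple of `dim` sizes, each drawn uniformly
    # from [min_size, max_size].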
    def _rand_shape(self, dim, min_size, max_size):
        shape = []
        for i in range(dim):
            shape.append(random.randint(min_size, max_size))
        return tuple(shape)

    # Validates that mathematical constants are defined properly, as required by
    # the Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html)
    @onlyCPU
    def test_constants(self, device):
        self.assertIsInstance(torch.e, float)
        self.assertEqual(torch.e, math.e, atol=0, rtol=0)

        self.assertIsInstance(torch.pi, float)
        self.assertEqual(torch.pi, math.pi, atol=0, rtol=0)

        self.assertIsInstance(torch.nan, float)
        self.assertEqual(torch.nan, math.nan, equal_nan=True)

        self.assertIsInstance(torch.inf, float)
        self.assertEqual(torch.inf, math.inf)

    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
            torch.bool, torch.float32, torch.complex64, torch.float64,
            torch.complex128, torch.uint16, torch.uint32, torch.uint64)
    def test_bytes_to_scalar(self, device, dtype):
        def rand_byte():
            if dtype == torch.bool:
                return torch.randint(0, 2, ()).item()
            else:
                return torch.randint(0, 256, ()).item()

        element_size = torch._utils._element_size(dtype)

        for i in range(10):
            bytes_list = [rand_byte() for _ in range(element_size)]
            scalar = bytes_to_scalar(bytes_list, dtype, device)
            self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)

    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
            torch.bool, torch.float32, torch.complex64, torch.float64,
            torch.complex128, torch.uint16, torch.uint32, torch.uint64)
    def test_storage(self, device, dtype):
        v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9)
        self.assertEqual(v.storage()[0], v[0][0])
        self.assertEqual(v.storage()[14], v[2][4])
        v_s = v.storage()

        for el_num in range(v.numel()):
            dim0 = el_num // v.size(1)
            dim1 = el_num % v.size(1)
            self.assertEqual(
                v_s[el_num],
                v[dim0][dim1])

        v_s_byte = v.storage().untyped()
        el_size = v.element_size()

        for el_num in range(v.numel()):
            start = el_num * el_size
            end = start + el_size
            dim0 = el_num // v.size(1)
            dim1 = el_num % v.size(1)
            self.assertEqual(
                bytes_to_scalar(v_s_byte[start:end], dtype, device),
                v[dim0][dim1])

    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
            torch.bool, torch.float32, torch.complex64, torch.float64,
            torch.complex128, torch.quint8, torch.qint8, torch.qint32,
            torch.quint4x2)
    def test_storage_setitem(self, device, dtype):
        # Skip quantized dtypes for CUDA, since they're not supported
        if torch.device(device).type == 'cuda':
            if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]:
                return

        storage_type_name = torch.storage._dtype_to_storage_type_map()[dtype]
        if torch.device(device).type == 'cuda':
            storage_type = eval('torch.cuda.' + storage_type_name)
        else:
            storage_type = eval('torch.' + storage_type_name)

        N = 10

        s = storage_type(N)
        s[:] = 0
        l = [0] * N
        self.assertEqual(s, storage_type(l))

        for i in range(N):
            s[i] = i
            l[i] = i

        self.assertEqual(s, storage_type(l))

        l[2:7] = [1] * 5
        s[2:7] = 1
        self.assertEqual(s, storage_type(l))

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_tensor_storage_type(self, device, dtype):
        a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)

        module = torch.cuda if (torch.device(device).type == 'cuda') else torch
        expected_storage_type = getattr(module, torch.storage._dtype_to_storage_type_map()[dtype])

        self.assertEqual(a.storage_type(), expected_storage_type)

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
    def test_tensor_from_storage(self, device, dtype):
        a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
        a_s = a.storage()
        b = torch.tensor(a_s, device=device, dtype=dtype).reshape(a.size())
        self.assertEqual(a, b)
        c = torch.tensor(a_s.untyped(), device=device, dtype=dtype).reshape(a.size())
        self.assertEqual(a, c)

        for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
            if error_dtype == dtype:
                continue
            with self.assertRaisesRegex(RuntimeError, r'Expected a Storage of type'):
                error_storage = a.to(error_dtype).storage()
                torch.tensor(error_storage, device=device, dtype=dtype)

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_set_storage(self, device, dtype):
        a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
        a_s = a.storage()
        b = torch.tensor([], device=device, dtype=dtype).set_(a_s).reshape(a.size())
        self.assertEqual(a, b)
        c = torch.tensor([], device=device, dtype=dtype).set_(a_s.untyped()).reshape(a.size())
        self.assertEqual(a, c)

        for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
            if error_dtype == dtype:
                continue
            with self.assertRaisesRegex(RuntimeError, r'Expected a Storage of type'):
                error_storage = a.to(error_dtype).storage()
                b = torch.tensor([], device=device, dtype=dtype).set_(error_storage)

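    # Checks that a meta-device storage `s` mirrors a real storage `s_check`:
    # same type, nbytes, and size, a null data_ptr, and elements that cannot
    # be read.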
    def _check_storage_meta(self, s, s_check):
        self.assertTrue(
            isinstance(s, (torch.UntypedStorage, torch.TypedStorage)) and
            isinstance(s_check, type(s)),
            (
                's and s_check must both be one of UntypedStorage or '
                'TypedStorage, but got'
                f' {type(s).__name__} and {type(s_check).__name__}'))

        self.assertEqual(s.device.type, 'meta')
        self.assertEqual(s.nbytes(), s_check.nbytes())
        self.assertEqual(s.size(), s_check.size())
        self.assertEqual(s.data_ptr(), 0)

        with self.assertRaisesRegex(NotImplementedError, r'Not available'):
            s[0]

        if isinstance(s, torch.TypedStorage):
            self.assertEqual(s.dtype, s_check.dtype)
            self._check_storage_meta(s.untyped(), s_check.untyped())

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_typed_storage_meta(self, device, dtype):
        args_list = [
            [],
            [0],
            [100],
            [[1, 2, 3, 4, 5, 6]],
        ]
        for args in args_list:
            s_check = torch.TypedStorage(*args, dtype=dtype, device=device)
            s = torch.TypedStorage(*args, dtype=dtype, device='meta')
            self._check_storage_meta(s, s_check)

    @onlyNativeDeviceTypes
    def test_untyped_storage_meta(self, device):
        args_list = [
            [],
            [0],
            [100],
            [[1, 2, 3, 4, 5, 6]],
        ]
        for args in args_list:
            s_check = torch.UntypedStorage(*args, device=device)
            s = torch.UntypedStorage(*args, device='meta')
            self._check_storage_meta(s, s_check)

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_storage_meta_from_tensor(self, device, dtype):
        t_check = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
        t = t_check.to('meta')

        s_check = t_check.storage()
        s = t.storage()
        self._check_storage_meta(s, s_check)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_storage_meta_errors(self, device, dtype):
        s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)

        with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
            s0.cpu()

        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
            s0._share_fd_cpu_()

        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
            s0._share_filename_cpu_()

        if torch.cuda.is_available():
            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
                s0.cuda()

            with self.assertRaisesRegex(RuntimeError, r'only available on CUDA'):
                s0._share_cuda_()

            with self.assertRaisesRegex(TypeError, r"cannot pin 'torch.storage.UntypedStorage' only CPU memory can be pinned"):
                s0.pin_memory()

        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
            s0.share_memory_()

        with self.assertRaisesRegex(NotImplementedError, r'Not available'):
            s0.tolist()

        with tempfile.NamedTemporaryFile() as f:
            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
                s0._write_file(f, True, True, s0.element_size())

        for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
            s1 = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)

            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
                s1.copy_(s0)

    @onlyCPU
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_storage_meta_ok(self, device, dtype):
        s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)

        # This is OK, it changes the meta storage size without allocating
        s0.resize_(10)

    @onlyCUDA
    def test_module_share_memory(self):
        # Test fix for issue #80733
        # See https://github.com/pytorch/pytorch/issues/80733
        model = torch.nn.Linear(3, 1)
        model_cuda = model.to('cuda')
        model.share_memory()

    @dtypes(torch.float32, torch.complex64)
    def test_deepcopy(self, device, dtype):
        from copy import deepcopy
        a = torch.randn(5, 5, dtype=dtype, device=device)
        b = torch.randn(5, 5, dtype=dtype, device=device)
        c = a.view(25)
        q = [a, [a.storage(), b.storage()], b, c]
        w = deepcopy(q)
        self.assertEqual(w[0], q[0], atol=0, rtol=0)
        self.assertEqual(w[1][0], q[1][0], atol=0, rtol=0)
        self.assertEqual(w[1][1], q[1][1], atol=0, rtol=0)
        self.assertEqual(w[1], q[1], atol=0, rtol=0)
        self.assertEqual(w[2], q[2], atol=0, rtol=0)

        # Check that deepcopy preserves sharing
        w[0].add_(1)
        for i in range(a.numel()):
            self.assertEqual(w[1][0][i], q[1][0][i] + 1)
        self.assertEqual(w[3], c + 1)
        w[2].sub_(1)
        for i in range(a.numel()):
            self.assertEqual(w[1][1][i], q[1][1][i] - 1)

        # Check that deepcopy preserves attributes
        a.foo = 3
        self.assertEqual(deepcopy(a).foo, 3)

    @dtypes(torch.float32, torch.complex64)
    def test_deepcopy_scalar(self, device, dtype):
        from copy import deepcopy
        a = torch.tensor(5, dtype=dtype, device=device)
        self.assertEqual(a.size(), deepcopy(a).size())
        self.assertEqual(a, deepcopy(a))

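    # Applies an in-place op to a tensor that internally overlaps itself (an
    # expanded one-element tensor) and expects the 'single memory location'
    # RuntimeError; with expected_failure=True the op is known not to raise yet.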
    def check_internal_mem_overlap(self, inplace_op, num_inputs,
                                   dtype, device,
                                   expected_failure=False):
        if isinstance(inplace_op, str):
            inplace_op = getattr(torch.Tensor, inplace_op)
        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
        inputs = [input] + [torch.randn_like(input)
                            for i in range(num_inputs - 1)]
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, 'single memory location'):
                inplace_op(*inputs)
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, 'single memory location'):
                    inplace_op(*inputs)

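    # Checks out= overlap handling for a unary op: writing into the input itself
    # or into independent memory must work, while partially overlapping or
    # transposed views of the input should raise 'unsupported operation'
    # (again inverted when expected_failure is set).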
    def unary_check_input_output_mem_overlap(self, data, sz, op,
                                             expected_failure=False):

        def _test(op, output, input):
            output_exp = torch.empty_like(output)
            op(input, out=output_exp)
            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)

        # output is identical to input:
        _test(op, output=data[0:sz], input=data[0:sz])
        # output and input are independent:
        _test(op, output=data[0:sz], input=data[sz:2 * sz])
        # output partially overlaps with input:
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
                _test(op, data[0:sz], data[1:sz + 1])
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
                    _test(op, data[0:sz], data[1:sz + 1])
        # output is transpose of input:
        length = int(math.sqrt(sz))
        input = data[:length**2].view([length, length])
        out = input.t()
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
                _test(op, out, input)
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
                    _test(op, out, input)

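    # Reuses the unary check above for a ternary op by pinning two of the three
    # inputs and cycling the overlapping tensor through each argument position.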
    def ternary_check_input_output_mem_overlap(self, op, device,
                                               expected_failure=False):
        sz = 9
        data = torch.randn(2 * sz, device=device)
        other1 = torch.randn(sz, device=device)
        other2 = torch.randn(sz, device=device)

        self.unary_check_input_output_mem_overlap(
            data, sz, lambda input, out:
                op(input, other1.view(input.shape), other2.view(input.shape), out=out),
            expected_failure=expected_failure)

        self.unary_check_input_output_mem_overlap(
            data, sz, lambda input, out:
                op(other1.view(input.shape), input, other2.view(input.shape), out=out),
            expected_failure=expected_failure)

        self.unary_check_input_output_mem_overlap(
            data, sz, lambda input, out:
                op(other1.view(input.shape), other2.view(input.shape), input, out=out),
            expected_failure=expected_failure)

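    # Picks shapes for broadcasting tests: returns (dims_small, dims_large,
    # dims_full) such that dims_small broadcasts against dims_large to give
    # dims_full, e.g. dims_full=[2, 3, 4] might yield dims_small=[3, 1] and
    # dims_large=[2, 3, 4].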
    def _select_broadcastable_dims(self, dims_full=None):
        # select full dimensionality
        if dims_full is None:
            dims_full = []
            ndims = random.randint(1, 4)
            dims_full = [random.randint(1, 8) for _ in range(ndims)]
        else:
            ndims = len(dims_full)

        # select actual dimensions for ops:
        # larger: full ndims, individual sizes may be reduced
        # smaller: possibly reduced ndims, sizes may be reduced
        smaller_ndims = random.randint(1, ndims)
        dims_small = []
        dims_large = []
        for i in range(ndims - 1, -1, -1):
            j = random.randint(1, 3)
            if j == 1:  # no reduced singleton dimension
                ds = dims_full[i]
                dl = dims_full[i]
            elif j == 2:  # larger may have reduced singleton dimension
                ds = dims_full[i]
                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
            elif j == 3:  # smaller may have reduced singleton dimension
                ds = 1
                dl = dims_full[i]
            dims_large = [dl] + dims_large
            if len(dims_small) < smaller_ndims:
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)

    # collected tests of ops that used scalar_check in Declarations.cwrap for
    # correctness
    def test_scalar_check(self, device):
        zero_d = torch.randn((), device=device)
        one_d = torch.randn((1,), device=device)

        # remainder
        self.assertEqual((), torch.remainder(zero_d, zero_d).shape)
        self.assertEqual((), torch.remainder(zero_d, 2).shape)
        self.assertEqual((1,), torch.remainder(zero_d, one_d).shape)
        self.assertEqual((1,), torch.remainder(one_d, zero_d).shape)

        # fmod
        self.assertEqual((), torch.fmod(zero_d, zero_d).shape)
        self.assertEqual((), torch.fmod(zero_d, 2).shape)
        self.assertEqual((1,), torch.fmod(zero_d, one_d).shape)
        self.assertEqual((1,), torch.fmod(one_d, zero_d).shape)

        # exp, cos, cosh, tan, atan, tanh, erf, erfc, reciprocal
        self.assertEqual((), torch.exp(zero_d).shape)
        self.assertEqual((), torch.cos(zero_d).shape)
        self.assertEqual((), torch.cosh(zero_d).shape)
        self.assertEqual((), torch.tan(zero_d).shape)
        self.assertEqual((), torch.atan(zero_d).shape)
        self.assertEqual((), torch.acosh(zero_d).shape)
        self.assertEqual((), torch.asinh(zero_d).shape)
        self.assertEqual((), torch.atanh(zero_d).shape)
        self.assertEqual((), torch.tanh(zero_d).shape)
        self.assertEqual((), torch.erf(zero_d).shape)
        self.assertEqual((), torch.erfc(zero_d).shape)
        self.assertEqual((), torch.reciprocal(zero_d).shape)
        self.assertEqual((1,), torch.exp(one_d).shape)
        self.assertEqual((1,), torch.cos(one_d).shape)
        self.assertEqual((1,), torch.cosh(one_d).shape)
        self.assertEqual((1,), torch.tan(one_d).shape)
        self.assertEqual((1,), torch.atan(one_d).shape)
        self.assertEqual((1,), torch.acosh(one_d).shape)
        self.assertEqual((1,), torch.asinh(one_d).shape)
        self.assertEqual((1,), torch.atanh(one_d).shape)
        self.assertEqual((1,), torch.tanh(one_d).shape)
        self.assertEqual((1,), torch.erf(one_d).shape)
        self.assertEqual((1,), torch.erfc(one_d).shape)
        self.assertEqual((1,), torch.reciprocal(one_d).shape)

        # clamp
        self.assertEqual((), torch.clamp(zero_d, min=0, max=1).shape)
        self.assertEqual((), torch.clamp(zero_d, min=0).shape)
        self.assertEqual((), torch.clamp(zero_d, max=1).shape)
        self.assertEqual((1,), torch.clamp(one_d, min=0, max=1).shape)
        self.assertEqual((1,), torch.clamp(one_d, min=0).shape)
        self.assertEqual((1,), torch.clamp(one_d, max=1).shape)

        # cumsum, cumprod, cummax, cummin
        self.assertEqual((), torch.logcumsumexp(zero_d, 0).shape)
        self.assertEqual((), torch.cumsum(zero_d, 0).shape)
        self.assertEqual((), torch.cumprod(zero_d, 0).shape)
        self.assertEqual((), torch.cummax(zero_d, 0)[0].shape)
        self.assertEqual((), torch.cummin(zero_d, 0)[0].shape)

        # sort, topk
        self.assertEqual([(), ()], [x.shape for x in torch.sort(zero_d, 0, False)])
        self.assertEqual([(), ()], [x.shape for x in torch.sort(zero_d, 0, True)])
        self.assertEqual([(), ()], [x.shape for x in torch.topk(zero_d, 1, 0, False)])
        self.assertEqual([(), ()], [x.shape for x in torch.topk(zero_d, 1, 0, True)])

        # max, min
        self.assertEqual((), torch.max(zero_d, zero_d).shape)
        self.assertEqual((1,), torch.max(one_d, zero_d).shape)
        self.assertEqual((1,), torch.max(zero_d, one_d).shape)
        self.assertEqual((), torch.min(zero_d, zero_d).shape)
        self.assertEqual((1,), torch.min(one_d, zero_d).shape)
        self.assertEqual((1,), torch.min(zero_d, one_d).shape)

        zero_d_int = torch.tensor(1, device=device)
        one_d_int = torch.tensor([1], device=device)

        # lshift, rshift
        self.assertEqual((), (zero_d_int >> zero_d_int).shape)
        self.assertEqual((), (zero_d_int >> 1).shape)
        self.assertEqual((1,), (one_d_int >> zero_d_int).shape)
        self.assertEqual((1,), (zero_d_int >> one_d_int).shape)
        self.assertEqual((1,), (one_d_int >> 1).shape)

        self.assertEqual((), (zero_d_int << zero_d_int).shape)
        self.assertEqual((), (zero_d_int << 1).shape)
        self.assertEqual((1,), (one_d_int << zero_d_int).shape)
        self.assertEqual((1,), (zero_d_int << one_d_int).shape)
        self.assertEqual((1,), (one_d_int << 1).shape)

        # or
        self.assertEqual((), (zero_d_int | zero_d_int).shape)
        self.assertEqual((), (zero_d_int | 1).shape)
        self.assertEqual((1,), (one_d_int | zero_d_int).shape)
        self.assertEqual((1,), (zero_d_int | one_d_int).shape)
        self.assertEqual((1,), (one_d_int | 1).shape)

        # and
        self.assertEqual((), (zero_d_int & zero_d_int).shape)
        self.assertEqual((), (zero_d_int & 1).shape)
        self.assertEqual((1,), (one_d_int & zero_d_int).shape)
        self.assertEqual((1,), (zero_d_int & one_d_int).shape)
        self.assertEqual((1,), (one_d_int & 1).shape)

        # clone
        self.assertEqual((), zero_d.clone().shape)

        zero_d_bool = torch.tensor(True, device=device)
        one_d_bool = torch.tensor([True], device=device)

        # masked_select
        self.assertEqual((1,), torch.masked_select(zero_d_bool, zero_d_bool).shape)
        self.assertEqual((1,), torch.masked_select(zero_d_bool, one_d_bool).shape)
        self.assertEqual((1,), torch.masked_select(one_d_bool, zero_d_bool).shape)

        zero_d_uint8 = torch.tensor(1, dtype=torch.uint8, device=device)
        one_d_uint8 = torch.tensor([1], dtype=torch.uint8, device=device)

        # mode
        self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=False)])
        self.assertEqual([(1,), (1,)], [x.shape for x in torch.mode(one_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.mode(one_d, dim=0, keepdim=False)])

        # max
        self.assertEqual([(), ()], [x.shape for x in torch.max(zero_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.max(zero_d, dim=0, keepdim=False)])
        self.assertEqual([(1,), (1,)], [x.shape for x in torch.max(one_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.max(one_d, dim=0, keepdim=False)])

        # amax
        self.assertEqual((), torch.amax(zero_d, dim=0, keepdim=True).shape)
        self.assertEqual((), torch.amax(zero_d, dim=0, keepdim=False).shape)
        self.assertEqual((1,), torch.amax(one_d, dim=0, keepdim=True).shape)
        self.assertEqual((), torch.amax(one_d, dim=0, keepdim=False).shape)

        # min
        self.assertEqual([(), ()], [x.shape for x in torch.min(zero_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.min(zero_d, dim=0, keepdim=False)])
        self.assertEqual([(1,), (1,)], [x.shape for x in torch.min(one_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.min(one_d, dim=0, keepdim=False)])

        # amin
        self.assertEqual((), torch.amin(zero_d, dim=0, keepdim=True).shape)
        self.assertEqual((), torch.amin(zero_d, dim=0, keepdim=False).shape)
        self.assertEqual((1,), torch.amin(one_d, dim=0, keepdim=True).shape)
        self.assertEqual((), torch.amin(one_d, dim=0, keepdim=False).shape)

        # set_
        zero_d_clone = zero_d.clone()
        one_d_clone = one_d.clone()
        self.assertEqual((), zero_d_clone.set_(one_d.storage(), 0, (), ()).shape)
        self.assertEqual((1,), zero_d_clone.set_(one_d.storage(), 0, (1,), (1,)).shape)
        self.assertEqual((), one_d_clone.set_(one_d.storage(), 0, (), ()).shape)
        self.assertEqual((1,), one_d_clone.set_(one_d.storage(), 0, (1,), (1,)).shape)

        self.assertEqual((), zero_d.clone().set_(zero_d).shape)
        self.assertEqual((), one_d.clone().set_(zero_d).shape)
        self.assertEqual((1,), zero_d.clone().set_(one_d).shape)
        self.assertEqual((1,), one_d.clone().set_(one_d).shape)

        # take
        self.assertEqual((), torch.randn((2, 3), device=device).take(zero_d_int).shape)
        self.assertEqual((1,), torch.randn((2, 3), device=device).take(one_d_int).shape)

        # gather
698*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((), torch.gather(zero_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape)
699*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((1,), torch.gather(zero_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape)
700*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((), torch.gather(one_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape)
701*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((1,), torch.gather(one_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape)
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker        # normal
704*da0073e9SAndroid Build Coastguard Worker        # std must be >= 0
705*da0073e9SAndroid Build Coastguard Worker        zero_d_ge_0 = torch.rand((), device=device)
706*da0073e9SAndroid Build Coastguard Worker        # documentation says out shape matches shape of mean
707*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((), torch.normal(zero_d, zero_d_ge_0).shape)
708*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((1,), torch.normal(one_d, zero_d_ge_0).shape)
709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((), torch.normal(1, zero_d_ge_0).shape)
710*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((), torch.normal(zero_d, 1).shape)
711*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((1,), torch.normal(one_d, 1).shape)
712*da0073e9SAndroid Build Coastguard Worker        # TODO: this behavior differs on CPU and GPU, see https://github.com/pytorch/pytorch/issues/30480.
713*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual((), torch.normal(zero_d, one_d).shape)
714*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual((), torch.normal(1, one_d).shape)
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker        # convolutions. Yes, we are testing nn.functional here; this seems justified
717*da0073e9SAndroid Build Coastguard Worker        # given it's similar to the other tests
718*da0073e9SAndroid Build Coastguard Worker        w = torch.randn(2, 1, 3, 3, device=device).div_(2).requires_grad_()
719*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.nn.functional.conv2d(zero_d, w, groups=1))
720*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.nn.functional.conv2d(zero_d, w, groups=2))
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Worker        # nll_loss -- verify input can't be 0-dimensional.
723*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, zero_d, reduction='none'))
724*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, one_d, reduction='none'))
725*da0073e9SAndroid Build Coastguard Worker        # verify output is 0-dimensional when reduction != 'none'
726*da0073e9SAndroid Build Coastguard Worker        for (input, target) in ((torch.randn(1, 1, device=device), torch.tensor([0], device=device)),
727*da0073e9SAndroid Build Coastguard Worker                                (torch.randn(1, 1, 1, 1, device=device), torch.tensor([[[0]]], device=device))):
728*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='mean').shape)
729*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='sum').shape)
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker    # Test that `torch._check_tensor_all` raises errors in the correct cases
732*da0073e9SAndroid Build Coastguard Worker    def test_check_tensor_all(self, device):
733*da0073e9SAndroid Build Coastguard Worker        default_message = 'Expected cond to be True'
734*da0073e9SAndroid Build Coastguard Worker        check_fn = torch._check_tensor_all
735*da0073e9SAndroid Build Coastguard Worker        expected_error = RuntimeError
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker        # cond must be a tensor
738*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, 'cond must be a tensor'):
739*da0073e9SAndroid Build Coastguard Worker            check_fn(True)
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker        # cond tensor must be boolean
742*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, 'cond tensor must have dtype torch.bool'):
743*da0073e9SAndroid Build Coastguard Worker            check_fn(torch.ones(1, device=device))
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker        test_sizes = [
746*da0073e9SAndroid Build Coastguard Worker            (),
747*da0073e9SAndroid Build Coastguard Worker            (1,),
748*da0073e9SAndroid Build Coastguard Worker            (10,),
749*da0073e9SAndroid Build Coastguard Worker            (1, 1),
750*da0073e9SAndroid Build Coastguard Worker            (1, 10),
751*da0073e9SAndroid Build Coastguard Worker            (10, 1),
752*da0073e9SAndroid Build Coastguard Worker            (10, 10),
753*da0073e9SAndroid Build Coastguard Worker            (1, 1, 1),
754*da0073e9SAndroid Build Coastguard Worker            (10, 1, 1),
755*da0073e9SAndroid Build Coastguard Worker            (1, 10, 1),
756*da0073e9SAndroid Build Coastguard Worker            (10, 10, 10),
757*da0073e9SAndroid Build Coastguard Worker        ]
758*da0073e9SAndroid Build Coastguard Worker        for size in test_sizes:
759*da0073e9SAndroid Build Coastguard Worker            t_all_true = torch.ones(size, dtype=torch.bool, device=device)
760*da0073e9SAndroid Build Coastguard Worker            t_all_false = torch.zeros(size, dtype=torch.bool, device=device)
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker            # Should not raise error
763*da0073e9SAndroid Build Coastguard Worker            check_fn(t_all_true)
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(expected_error, default_message):
766*da0073e9SAndroid Build Coastguard Worker                check_fn(t_all_false)
767*da0073e9SAndroid Build Coastguard Worker
768*da0073e9SAndroid Build Coastguard Worker            if t_all_true.numel() > 1:
769*da0073e9SAndroid Build Coastguard Worker                t_all_true_but_one = t_all_true.clone()
770*da0073e9SAndroid Build Coastguard Worker                # Choose a random element to set to false
771*da0073e9SAndroid Build Coastguard Worker                idx = (random.choice(range(dim_size)) for dim_size in size)
772*da0073e9SAndroid Build Coastguard Worker                t_all_true_but_one[(..., *idx)] = False
773*da0073e9SAndroid Build Coastguard Worker
774*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(expected_error, default_message):
775*da0073e9SAndroid Build Coastguard Worker                    check_fn(t_all_true_but_one)
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker            # Test a simple failure message
778*da0073e9SAndroid Build Coastguard Worker            message = 'message'
779*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(expected_error, message):
780*da0073e9SAndroid Build Coastguard Worker                check_fn(t_all_false, lambda: message)
781*da0073e9SAndroid Build Coastguard Worker
782*da0073e9SAndroid Build Coastguard Worker            # Test message with tensor
783*da0073e9SAndroid Build Coastguard Worker            def message():
784*da0073e9SAndroid Build Coastguard Worker                return torch.arange(4)
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
787*da0073e9SAndroid Build Coastguard Worker                check_fn(t_all_false, message)
788*da0073e9SAndroid Build Coastguard Worker
789*da0073e9SAndroid Build Coastguard Worker            # Test format string message
790*da0073e9SAndroid Build Coastguard Worker            def message():
791*da0073e9SAndroid Build Coastguard Worker                return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}"
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
794*da0073e9SAndroid Build Coastguard Worker                check_fn(t_all_false, message)
795*da0073e9SAndroid Build Coastguard Worker
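    # A minimal usage sketch (not part of the original suite): it shows the
    # `torch._check_tensor_all` calling pattern exercised above, where the
    # optional second argument is a callable whose return value becomes the
    # error message and is only evaluated when the check fails.
    def _sketch_check_tensor_all_usage(self):
        torch._check_tensor_all(torch.ones(2, dtype=torch.bool))  # all True: passes silently
        cond = torch.tensor([True, False, True])
        try:
            torch._check_tensor_all(cond, lambda: "all elements must be True")
        except RuntimeError as e:
            assert "all elements must be True" in str(e)
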
796*da0073e9SAndroid Build Coastguard Worker    # Test that `TORCH_CHECK_TENSOR_ALL` raises errors that propagate from C++ to Python
797*da0073e9SAndroid Build Coastguard Worker    def test_check_tensor_internal(self, device):
798*da0073e9SAndroid Build Coastguard Worker        test_sizes = [
799*da0073e9SAndroid Build Coastguard Worker            (),
800*da0073e9SAndroid Build Coastguard Worker            (1,),
801*da0073e9SAndroid Build Coastguard Worker            (10,),
802*da0073e9SAndroid Build Coastguard Worker            (1, 1),
803*da0073e9SAndroid Build Coastguard Worker            (1, 10),
804*da0073e9SAndroid Build Coastguard Worker            (10, 1),
805*da0073e9SAndroid Build Coastguard Worker            (10, 10),
806*da0073e9SAndroid Build Coastguard Worker            (1, 1, 1),
807*da0073e9SAndroid Build Coastguard Worker            (10, 1, 1),
808*da0073e9SAndroid Build Coastguard Worker            (1, 10, 1),
809*da0073e9SAndroid Build Coastguard Worker            (10, 10, 10),
810*da0073e9SAndroid Build Coastguard Worker        ]
811*da0073e9SAndroid Build Coastguard Worker        for size in test_sizes:
812*da0073e9SAndroid Build Coastguard Worker            t_all_true = torch.ones(size, dtype=torch.bool, device=device)
813*da0073e9SAndroid Build Coastguard Worker            t_all_false = torch.zeros(size, dtype=torch.bool, device=device)
814*da0073e9SAndroid Build Coastguard Worker
815*da0073e9SAndroid Build Coastguard Worker            # Should not raise error
816*da0073e9SAndroid Build Coastguard Worker            torch._test_check_tensor(t_all_true)
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
819*da0073e9SAndroid Build Coastguard Worker                torch._test_check_tensor(t_all_false)
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker            if t_all_true.numel() > 1:
822*da0073e9SAndroid Build Coastguard Worker                t_all_true_but_one = t_all_true.clone()
823*da0073e9SAndroid Build Coastguard Worker                # Choose a random element to set to false
824*da0073e9SAndroid Build Coastguard Worker                idx = (random.choice(range(dim_size)) for dim_size in size)
825*da0073e9SAndroid Build Coastguard Worker                t_all_true_but_one[(..., *idx)] = False
826*da0073e9SAndroid Build Coastguard Worker
827*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
828*da0073e9SAndroid Build Coastguard Worker                    torch._test_check_tensor(t_all_true_but_one)
829*da0073e9SAndroid Build Coastguard Worker
830*da0073e9SAndroid Build Coastguard Worker    # Uses mismatched arange out size to trigger a warning
831*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
832*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_CROSSREF, "crossref perturbs line numbering")
833*da0073e9SAndroid Build Coastguard Worker    def test_cpp_warnings_have_python_context(self, device):
834*da0073e9SAndroid Build Coastguard Worker        # Creates long string in advance to avoid a too-long Python line
835*da0073e9SAndroid Build Coastguard Worker        s = ".+Triggered internally at.+RangeFactories.+"
836*da0073e9SAndroid Build Coastguard Worker        # nvfuser deprecation warning filter
837*da0073e9SAndroid Build Coastguard Worker        warnings.filterwarnings("ignore", "torch::jit::fuser::cuda", UserWarning)
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker        def cpp_warn_fn():
840*da0073e9SAndroid Build Coastguard Worker            out = torch.empty((5,))
841*da0073e9SAndroid Build Coastguard Worker            torch.arange(0, 3, out=out)
842*da0073e9SAndroid Build Coastguard Worker            return out
843*da0073e9SAndroid Build Coastguard Worker
844*da0073e9SAndroid Build Coastguard Worker        # Checks eager-mode cpp warning
845*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
846*da0073e9SAndroid Build Coastguard Worker            cpp_warn_fn()
847*da0073e9SAndroid Build Coastguard Worker            frameinfo = inspect.getframeinfo(inspect.currentframe())
848*da0073e9SAndroid Build Coastguard Worker            warning = w[0]
849*da0073e9SAndroid Build Coastguard Worker
850*da0073e9SAndroid Build Coastguard Worker            # Checks for cpp context in the warning message
851*da0073e9SAndroid Build Coastguard Worker            escaped_warning_message = str(warning.message).encode('unicode_escape')
852*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(re.search(s, repr(escaped_warning_message), re.IGNORECASE) is not None)
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker            # Checks the Python features of the warning
855*da0073e9SAndroid Build Coastguard Worker            # Note: the eager mode warning refers to the line in the function
856*da0073e9SAndroid Build Coastguard Worker            # that throws the warning.
857*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(frameinfo.lineno - 6, warning.lineno)
858*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker        # Checks jitted cpp warning
861*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
862*da0073e9SAndroid Build Coastguard Worker            scripted_cpp_warn_fn = torch.jit.script(cpp_warn_fn)
863*da0073e9SAndroid Build Coastguard Worker            scripted_cpp_warn_fn()
864*da0073e9SAndroid Build Coastguard Worker            warning = w[0]
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker            # Checks for cpp context in the warning message
867*da0073e9SAndroid Build Coastguard Worker            escaped_warning_message = str(warning.message).encode('unicode_escape')
868*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(re.search(s, repr(escaped_warning_message), re.IGNORECASE) is not None)
869*da0073e9SAndroid Build Coastguard Worker
870*da0073e9SAndroid Build Coastguard Worker            # Checks the Python features of the warning
871*da0073e9SAndroid Build Coastguard Worker            # Note: the jitted warning's lineno refers to the call to the jitted
872*da0073e9SAndroid Build Coastguard Worker            # function, which in our test suite has a layer of indirection
873*da0073e9SAndroid Build Coastguard Worker            # that makes checking the Python lineno fragile
874*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Worker        # Checks jitted Python warning
877*da0073e9SAndroid Build Coastguard Worker        def warn_fn():
878*da0073e9SAndroid Build Coastguard Worker            warnings.warn("Warning!")
879*da0073e9SAndroid Build Coastguard Worker
880*da0073e9SAndroid Build Coastguard Worker        # The jit mimics an eager-mode Python warning in this case
881*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
882*da0073e9SAndroid Build Coastguard Worker            scripted_warn_fn = torch.jit.script(warn_fn)
883*da0073e9SAndroid Build Coastguard Worker            scripted_warn_fn()
884*da0073e9SAndroid Build Coastguard Worker            frameinfo = inspect.getframeinfo(inspect.currentframe())
885*da0073e9SAndroid Build Coastguard Worker            warning = w[0]
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(re.search('Warning!', str(warning.message)) is not None)
888*da0073e9SAndroid Build Coastguard Worker
889*da0073e9SAndroid Build Coastguard Worker            # Checks the Python features of the warning
890*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(frameinfo.lineno - 6, warning.lineno)
891*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
892*da0073e9SAndroid Build Coastguard Worker
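    # A minimal sketch (not part of the original suite): the mismatched
    # `out=` size used above makes torch.arange emit a warning from C++, and
    # that warning surfaces as a Python UserWarning attributed to the calling
    # Python file rather than to the C++ source.
    def _sketch_cpp_warning_python_context(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            out = torch.empty(5)
            torch.arange(0, 3, out=out)  # `out` is resized from 5 to 3 elements
        assert len(w) == 1
        assert issubclass(w[0].category, UserWarning)
        assert w[0].filename.endswith(".py")  # Python call site, not the C++ file
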
893*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_testing
894*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
895*da0073e9SAndroid Build Coastguard Worker    def test_warn_always_caught(self, device):
896*da0073e9SAndroid Build Coastguard Worker        # Check that we can catch a TORCH_WARN_ONCE warning twice
897*da0073e9SAndroid Build Coastguard Worker        # since assertWarnsOnceRegex uses set_warn_always(True) which changes
898*da0073e9SAndroid Build Coastguard Worker        # TORCH_WARN_ONCE to TORCH_WARN
899*da0073e9SAndroid Build Coastguard Worker        a = np.arange(10)
900*da0073e9SAndroid Build Coastguard Worker        a.flags.writeable = False
901*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
902*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(a)
903*da0073e9SAndroid Build Coastguard Worker
904*da0073e9SAndroid Build Coastguard Worker        # OK, got it once, now try again
905*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
906*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(a)
907*da0073e9SAndroid Build Coastguard Worker
908*da0073e9SAndroid Build Coastguard Worker        # Make sure emitting two warnings will pass the assertWarnsOnceRegex
909*da0073e9SAndroid Build Coastguard Worker        # context manager
910*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
911*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(a)
912*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(a)
913*da0073e9SAndroid Build Coastguard Worker
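    # A minimal sketch (not part of the original suite): it makes explicit the
    # mechanism described in the comment above -- torch.set_warn_always(True)
    # turns TORCH_WARN_ONCE warnings into ordinary warnings, so the
    # non-writable-array warning from torch.from_numpy fires on every call
    # instead of only once per process.
    def _sketch_set_warn_always(self):
        prev = torch.is_warn_always_enabled()
        torch.set_warn_always(True)
        try:
            a = np.arange(3)
            a.flags.writeable = False
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                torch.from_numpy(a)
                torch.from_numpy(a)
            assert len(w) == 2  # warned twice, not once
        finally:
            torch.set_warn_always(prev)
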
914*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
915*da0073e9SAndroid Build Coastguard Worker    def test_complex_half_experimental_warning(self, device):
916*da0073e9SAndroid Build Coastguard Worker        msg = 'ComplexHalf support is experimental'
917*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
918*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(3, dtype=torch.chalf, device=device)
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
921*da0073e9SAndroid Build Coastguard Worker            torch.rand(3, dtype=torch.chalf, device=device)
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
924*da0073e9SAndroid Build Coastguard Worker            torch.empty(3, dtype=torch.chalf, device=device)
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
927*da0073e9SAndroid Build Coastguard Worker            torch.ones(3, dtype=torch.chalf, device=device)
928*da0073e9SAndroid Build Coastguard Worker
929*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
930*da0073e9SAndroid Build Coastguard Worker            torch.zeros(3, dtype=torch.chalf, device=device)
931*da0073e9SAndroid Build Coastguard Worker
932*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
933*da0073e9SAndroid Build Coastguard Worker            torch.randn_like(t)
934*da0073e9SAndroid Build Coastguard Worker
935*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
936*da0073e9SAndroid Build Coastguard Worker            torch.rand_like(t)
937*da0073e9SAndroid Build Coastguard Worker
938*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
939*da0073e9SAndroid Build Coastguard Worker            torch.empty_like(t)
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
942*da0073e9SAndroid Build Coastguard Worker            torch.ones_like(t)
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
945*da0073e9SAndroid Build Coastguard Worker            torch.zeros_like(t)
946*da0073e9SAndroid Build Coastguard Worker
947*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
948*da0073e9SAndroid Build Coastguard Worker            # t + 1 allocates a new tensor for result using empty
949*da0073e9SAndroid Build Coastguard Worker            t + 1
950*da0073e9SAndroid Build Coastguard Worker
951*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
952*da0073e9SAndroid Build Coastguard Worker    def test_dtypetensor_warnings(self, device):
953*da0073e9SAndroid Build Coastguard Worker        msg = 'The torch.cuda.*DtypeTensor constructors are no longer recommended'
954*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
955*da0073e9SAndroid Build Coastguard Worker            t = torch.cuda.FloatTensor([0])
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, msg):
958*da0073e9SAndroid Build Coastguard Worker            t = torch.cuda.DoubleTensor([0])
959*da0073e9SAndroid Build Coastguard Worker
960*da0073e9SAndroid Build Coastguard Worker    def test_set_default_tensor_type_warnings(self, device):
961*da0073e9SAndroid Build Coastguard Worker        msg = '.*is deprecated as of PyTorch 2.1, please use torch.set_default_dtype().*'
962*da0073e9SAndroid Build Coastguard Worker        default_type = torch.tensor([]).type()
963*da0073e9SAndroid Build Coastguard Worker        try:
964*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsOnceRegex(UserWarning, msg):
965*da0073e9SAndroid Build Coastguard Worker                torch.set_default_tensor_type(torch.FloatTensor)
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.is_available():
968*da0073e9SAndroid Build Coastguard Worker                with self.assertWarnsOnceRegex(UserWarning, msg):
969*da0073e9SAndroid Build Coastguard Worker                    torch.set_default_tensor_type(torch.cuda.FloatTensor)
970*da0073e9SAndroid Build Coastguard Worker        finally:
971*da0073e9SAndroid Build Coastguard Worker            torch.set_default_tensor_type(default_type)
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker    # TODO: this test should be in test_nn.py
974*da0073e9SAndroid Build Coastguard Worker    def test_conv_transposed_backward_agnostic_to_memory_format(self, device):
975*da0073e9SAndroid Build Coastguard Worker        in_channels = 64
976*da0073e9SAndroid Build Coastguard Worker        out_channels = 128
977*da0073e9SAndroid Build Coastguard Worker        scale_factor = 8
978*da0073e9SAndroid Build Coastguard Worker        batch_size = 8
979*da0073e9SAndroid Build Coastguard Worker        length = 16
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker        conv = torch.nn.ConvTranspose1d(
982*da0073e9SAndroid Build Coastguard Worker            in_channels, out_channels, kernel_size=scale_factor * 2, stride=scale_factor).to(device)
983*da0073e9SAndroid Build Coastguard Worker        layer_norm = torch.nn.LayerNorm(out_channels).to(device)
984*da0073e9SAndroid Build Coastguard Worker
985*da0073e9SAndroid Build Coastguard Worker        input_ = torch.randn(batch_size, in_channels, length).to(device).contiguous()
986*da0073e9SAndroid Build Coastguard Worker        input_ = conv(input_).contiguous()
987*da0073e9SAndroid Build Coastguard Worker        input_ = layer_norm(input_.transpose(1, 2).contiguous()).contiguous()
988*da0073e9SAndroid Build Coastguard Worker        input_.sum().backward()
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker        # 3d
991*da0073e9SAndroid Build Coastguard Worker        conv = torch.nn.ConvTranspose3d(3, 3, kernel_size=3).to(device)
992*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(batch_size, 3, length, length, length, device=device)
993*da0073e9SAndroid Build Coastguard Worker        out = conv(input)
994*da0073e9SAndroid Build Coastguard Worker        out.backward(torch.ones_like(out).transpose(-2, -1))
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker    # TODO: this test should be in test_nn.py
997*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
998*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('12GB')
999*da0073e9SAndroid Build Coastguard Worker    def test_conv_transposed_large(self, device):
1000*da0073e9SAndroid Build Coastguard Worker        # ConvTranspose3d works for large input tensors (gh-32866)
1001*da0073e9SAndroid Build Coastguard Worker        in_channels = 64
1002*da0073e9SAndroid Build Coastguard Worker        out_channels = 128
1003*da0073e9SAndroid Build Coastguard Worker        kernel_size = 5
1004*da0073e9SAndroid Build Coastguard Worker
1005*da0073e9SAndroid Build Coastguard Worker        conv = torch.nn.ConvTranspose3d(
1006*da0073e9SAndroid Build Coastguard Worker            in_channels, out_channels, kernel_size=kernel_size,
1007*da0073e9SAndroid Build Coastguard Worker            stride=2, padding=2, output_padding=1).to(device)
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([1, 64, 8, 128, 172]).to(device)
1010*da0073e9SAndroid Build Coastguard Worker        y = conv(x)
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker    def test_is_set_to(self, device):
1013*da0073e9SAndroid Build Coastguard Worker        t1 = torch.empty(3, 4, 9, 10, device=device)
1014*da0073e9SAndroid Build Coastguard Worker        t2 = torch.empty(3, 4, 9, 10, device=device)
1015*da0073e9SAndroid Build Coastguard Worker        t3 = torch.tensor([], device=device).set_(t1)
1016*da0073e9SAndroid Build Coastguard Worker        t4 = t3.clone().resize_(12, 90)
1017*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t1.is_set_to(t2))
1018*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t1.is_set_to(t3))
1019*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t3.is_set_to(t1), "is_set_to should be symmetric")
1020*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t1.is_set_to(t4))
1021*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.tensor([]).is_set_to(torch.tensor([])),
1022*da0073e9SAndroid Build Coastguard Worker                         "Tensors with no storages should not appear to be set "
1023*da0073e9SAndroid Build Coastguard Worker                         "to each other")
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([True, True], dtype=torch.bool, device=device)
1026*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([0], dtype=torch.bool, device=device).set_(t1)
1027*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t1.is_set_to(t2))
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker        # test that sizes must match
1030*da0073e9SAndroid Build Coastguard Worker        t1 = torch.empty([2, 3, 4], device=device)
1031*da0073e9SAndroid Build Coastguard Worker        t2 = t1.view(4, 3, 2)
1032*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t1.is_set_to(t2))
1033*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t2.is_set_to(t1))
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker        # test that legacy empty size behavior used to be respected (i.e. all
1036*da0073e9SAndroid Build Coastguard Worker        # test that the legacy empty-size behavior (all empty tensors were
1037*da0073e9SAndroid Build Coastguard Worker        # logically collapsed to size [0]) is no longer respected.
1038*da0073e9SAndroid Build Coastguard Worker        t2 = t1.view([0])
1039*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t1.is_set_to(t2))
1040*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t2.is_set_to(t1))
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/72650
1043*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1044*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1045*da0073e9SAndroid Build Coastguard Worker    @parametrize(
1046*da0073e9SAndroid Build Coastguard Worker        "fn",
1047*da0073e9SAndroid Build Coastguard Worker        [
1048*da0073e9SAndroid Build Coastguard Worker            "dist", "atan2", "pow", "lerp", "add", "sub", "mul", "div", "fmod", "remainder", "eq", "ge", "gt", "le",
1049*da0073e9SAndroid Build Coastguard Worker            "lt", "max", "min", "ne", "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", "map",
1050*da0073e9SAndroid Build Coastguard Worker            "map2", "copy",
1051*da0073e9SAndroid Build Coastguard Worker        ],
1052*da0073e9SAndroid Build Coastguard Worker    )
1053*da0073e9SAndroid Build Coastguard Worker    def test_broadcast(self, fn, device):
1054*da0073e9SAndroid Build Coastguard Worker        # functions with three tensor arguments
1055*da0073e9SAndroid Build Coastguard Worker        fns_3_args = {"map2"}
1056*da0073e9SAndroid Build Coastguard Worker        fns_value_kwarg = {"addcdiv", "addcmul"}
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Worker        (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
1059*da0073e9SAndroid Build Coastguard Worker        full1d = torch.randn(*dims_full, device=device).flatten().float()
1060*da0073e9SAndroid Build Coastguard Worker        small = torch.randn(*dims_small, device=device).float()
1061*da0073e9SAndroid Build Coastguard Worker        large = torch.randn(*dims_large, device=device).float()
1062*da0073e9SAndroid Build Coastguard Worker        small_expanded = small.expand(*dims_full)
1063*da0073e9SAndroid Build Coastguard Worker        large_expanded = large.expand(*dims_full)
1064*da0073e9SAndroid Build Coastguard Worker        small2 = None
1065*da0073e9SAndroid Build Coastguard Worker        small2_expanded = None
1066*da0073e9SAndroid Build Coastguard Worker        if fn in fns_3_args or fn in fns_value_kwarg:
1067*da0073e9SAndroid Build Coastguard Worker            # create another smaller tensor
1068*da0073e9SAndroid Build Coastguard Worker            (dims_small2, _, _) = self._select_broadcastable_dims(dims_full)
1069*da0073e9SAndroid Build Coastguard Worker            small2 = torch.randn(*dims_small2, device=device).float()
1070*da0073e9SAndroid Build Coastguard Worker            small2_expanded = small2.expand(*dims_full)
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker        if small.is_cuda and fn in ['map', 'map2']:
1073*da0073e9SAndroid Build Coastguard Worker            # map and map2 are not implemented on CUDA tensors
1074*da0073e9SAndroid Build Coastguard Worker            return
1075*da0073e9SAndroid Build Coastguard Worker
1076*da0073e9SAndroid Build Coastguard Worker        if hasattr(large_expanded, fn):
1077*da0073e9SAndroid Build Coastguard Worker            # run through tensor versions of functions
1078*da0073e9SAndroid Build Coastguard Worker            # and verify fully expanded inputs give same results
1079*da0073e9SAndroid Build Coastguard Worker            expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Worker            def tensorfn(myfn, t1, t2):
1082*da0073e9SAndroid Build Coastguard Worker                if fn == "lerp":
1083*da0073e9SAndroid Build Coastguard Worker                    return myfn(t1, 0.5)
1084*da0073e9SAndroid Build Coastguard Worker                elif fn == "masked_select":
1085*da0073e9SAndroid Build Coastguard Worker                    return myfn(t1 < 0)
1086*da0073e9SAndroid Build Coastguard Worker                elif fn == "masked_scatter":
1087*da0073e9SAndroid Build Coastguard Worker                    return myfn(t1 < 0.5, full1d)
1088*da0073e9SAndroid Build Coastguard Worker                elif fn == "masked_fill":
1089*da0073e9SAndroid Build Coastguard Worker                    return myfn(t1 < 0.5, 1.0)
1090*da0073e9SAndroid Build Coastguard Worker                elif fn in fns_3_args:
1091*da0073e9SAndroid Build Coastguard Worker                    return myfn(1, t1, t2)
1092*da0073e9SAndroid Build Coastguard Worker                elif fn in fns_value_kwarg:
1093*da0073e9SAndroid Build Coastguard Worker                    return myfn(t1, t2, value=1)
1094*da0073e9SAndroid Build Coastguard Worker                else:
1095*da0073e9SAndroid Build Coastguard Worker                    return myfn(t1)
1096*da0073e9SAndroid Build Coastguard Worker
1097*da0073e9SAndroid Build Coastguard Worker            # test various orders
1098*da0073e9SAndroid Build Coastguard Worker            for first, second, third in [(large, small, small2), (small, large, small2),
1099*da0073e9SAndroid Build Coastguard Worker                                         (small2, small, large), (small2, large, small)]:
1100*da0073e9SAndroid Build Coastguard Worker                if first is None:
1101*da0073e9SAndroid Build Coastguard Worker                    break  # skip the remaining orderings once small2 is None
1102*da0073e9SAndroid Build Coastguard Worker                method_expanded = getattr(expanded[first], fn)
1103*da0073e9SAndroid Build Coastguard Worker                method = getattr(first, fn)
1104*da0073e9SAndroid Build Coastguard Worker                r1 = tensorfn(method_expanded, expanded[second], expanded[third])
1105*da0073e9SAndroid Build Coastguard Worker                r2 = tensorfn(method, second, third)
1106*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(r1, r2)
1107*da0073e9SAndroid Build Coastguard Worker
1108*da0073e9SAndroid Build Coastguard Worker        # now for torch. versions of functions
1109*da0073e9SAndroid Build Coastguard Worker        if hasattr(torch, fn):
1110*da0073e9SAndroid Build Coastguard Worker            fntorch = getattr(torch, fn)
1111*da0073e9SAndroid Build Coastguard Worker            expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
1112*da0073e9SAndroid Build Coastguard Worker
1113*da0073e9SAndroid Build Coastguard Worker            def torchfn(t1, t2, t3):
1114*da0073e9SAndroid Build Coastguard Worker                if fn == "lerp":
1115*da0073e9SAndroid Build Coastguard Worker                    return fntorch(t1, t2, 0.5)
1116*da0073e9SAndroid Build Coastguard Worker                elif fn == "masked_select":
1117*da0073e9SAndroid Build Coastguard Worker                    return fntorch(t1, t2 < 0)
1118*da0073e9SAndroid Build Coastguard Worker                elif fn == "masked_scatter":
1119*da0073e9SAndroid Build Coastguard Worker                    return fntorch(t1, t2 < 0.5, full1d)
1120*da0073e9SAndroid Build Coastguard Worker                elif fn == "masked_fill":
1121*da0073e9SAndroid Build Coastguard Worker                    return fntorch(t1, t2 < 0.5, 1.0)
1122*da0073e9SAndroid Build Coastguard Worker                elif fn in fns_3_args:
1123*da0073e9SAndroid Build Coastguard Worker                    return fntorch(t1, 1.0, t2, t3)
1124*da0073e9SAndroid Build Coastguard Worker                elif fn in fns_value_kwarg:
1125*da0073e9SAndroid Build Coastguard Worker                    return fntorch(t1, t2, t3, value=1.0)
1126*da0073e9SAndroid Build Coastguard Worker                else:
1127*da0073e9SAndroid Build Coastguard Worker                    return fntorch(t1, t2)
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker            # test various orders
1130*da0073e9SAndroid Build Coastguard Worker            for first, second, third in [(large, small, small2), (small, large, small2),
1131*da0073e9SAndroid Build Coastguard Worker                                         (small2, small, large), (small2, large, small)]:
1132*da0073e9SAndroid Build Coastguard Worker                if first is None:
1133*da0073e9SAndroid Build Coastguard Worker                    break  # skip the remaining orderings once small2 is None
1134*da0073e9SAndroid Build Coastguard Worker                r1 = torchfn(expanded[first], expanded[second], expanded[third])
1135*da0073e9SAndroid Build Coastguard Worker                r2 = torchfn(first, second, third)
1136*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(r1, r2)
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker        # now for in-place functions
1139*da0073e9SAndroid Build Coastguard Worker        # in-place tensor is not broadcastable; test only guaranteed
1140*da0073e9SAndroid Build Coastguard Worker        # to work by broadcasting other argument(s)
1141*da0073e9SAndroid Build Coastguard Worker        if not hasattr(large_expanded, fn + "_"):
1142*da0073e9SAndroid Build Coastguard Worker            return
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker        # need to clone large_expanded so we can reuse it, since these functions are in-place
1145*da0073e9SAndroid Build Coastguard Worker        large_expanded_clone = large_expanded.clone()
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Worker        def tensorfn_inplace(t0, t1, t2=None):
1148*da0073e9SAndroid Build Coastguard Worker            t0_fn = getattr(t0, fn + "_")
1149*da0073e9SAndroid Build Coastguard Worker            if fn == "lerp":
1150*da0073e9SAndroid Build Coastguard Worker                return t0_fn(t1, 0.5)
1151*da0073e9SAndroid Build Coastguard Worker            elif fn == "masked_scatter":
1152*da0073e9SAndroid Build Coastguard Worker                return t0_fn(t1 < 0.5, full1d)
1153*da0073e9SAndroid Build Coastguard Worker            elif fn == "masked_fill":
1154*da0073e9SAndroid Build Coastguard Worker                return t0_fn(t1 < 0.5, 1.0)
1155*da0073e9SAndroid Build Coastguard Worker            elif fn == "map":
1156*da0073e9SAndroid Build Coastguard Worker                return t0_fn(t1, lambda x, y: x + y)
1157*da0073e9SAndroid Build Coastguard Worker            elif fn == "map2":
1158*da0073e9SAndroid Build Coastguard Worker                return t0_fn(t1, t2, lambda x, y, z: x + y + z)
1159*da0073e9SAndroid Build Coastguard Worker            elif fn in fns_3_args:
1160*da0073e9SAndroid Build Coastguard Worker                return t0_fn(1.0, t1, t2)
1161*da0073e9SAndroid Build Coastguard Worker            elif fn in fns_value_kwarg:
1162*da0073e9SAndroid Build Coastguard Worker                return t0_fn(t1, t2, value=1.0)
1163*da0073e9SAndroid Build Coastguard Worker            else:
1164*da0073e9SAndroid Build Coastguard Worker                return t0_fn(t1)
1165*da0073e9SAndroid Build Coastguard Worker        # in-place pointwise operations don't actually work if the in-place
1166*da0073e9SAndroid Build Coastguard Worker        # tensor is 0-strided (numpy has the same issue)
1167*da0073e9SAndroid Build Coastguard Worker        if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
1168*da0073e9SAndroid Build Coastguard Worker            r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
1169*da0073e9SAndroid Build Coastguard Worker            r2 = tensorfn_inplace(large_expanded_clone, small, small2)
1170*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r1, r2)
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Worker        def broadcastable(t0, t1, t2=None):
1173*da0073e9SAndroid Build Coastguard Worker            try:
1174*da0073e9SAndroid Build Coastguard Worker                t1.expand_as(t0)
1175*da0073e9SAndroid Build Coastguard Worker                if t2 is not None:
1176*da0073e9SAndroid Build Coastguard Worker                    t2.expand_as(t0)
1177*da0073e9SAndroid Build Coastguard Worker            except RuntimeError:
1178*da0073e9SAndroid Build Coastguard Worker                return False
1179*da0073e9SAndroid Build Coastguard Worker            return True
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Worker        def _test_in_place_broadcastable(t0, t1, t2=None):
1182*da0073e9SAndroid Build Coastguard Worker            if not broadcastable(t0, t1, t2):
1183*da0073e9SAndroid Build Coastguard Worker                same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
1184*da0073e9SAndroid Build Coastguard Worker                if not same_size:
1185*da0073e9SAndroid Build Coastguard Worker                    # Functionalization converts the inplace to an out-of-place, which causes us to error.
1186*da0073e9SAndroid Build Coastguard Worker                    # We should fix this, but "error probably on bad inputs" isn't a hi-pri PT2 item.
1187*da0073e9SAndroid Build Coastguard Worker                    if not TEST_WITH_TORCHINDUCTOR:
1188*da0073e9SAndroid Build Coastguard Worker                        self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
1189*da0073e9SAndroid Build Coastguard Worker            else:
1190*da0073e9SAndroid Build Coastguard Worker                tensorfn_inplace(t0, t1, t2)
1191*da0073e9SAndroid Build Coastguard Worker
1192*da0073e9SAndroid Build Coastguard Worker        if fn not in fns_3_args and fn not in fns_value_kwarg:
1193*da0073e9SAndroid Build Coastguard Worker            _test_in_place_broadcastable(small, large_expanded)
1194*da0073e9SAndroid Build Coastguard Worker            _test_in_place_broadcastable(small, large)
1195*da0073e9SAndroid Build Coastguard Worker        else:
1196*da0073e9SAndroid Build Coastguard Worker            _test_in_place_broadcastable(small2, small_expanded, large_expanded)
1197*da0073e9SAndroid Build Coastguard Worker            _test_in_place_broadcastable(small2, small, large)
1198*da0073e9SAndroid Build Coastguard Worker
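    # A minimal sketch (not part of the original suite): it illustrates the
    # in-place broadcasting rule relied on above -- the other operand may be
    # broadcast to the shape of the in-place tensor, but the in-place tensor
    # itself is never broadcast, since its shape cannot change in place.
    def _sketch_in_place_broadcasting(self):
        big = torch.zeros(3, 3)
        small = torch.ones(3)
        big.add_(small)  # OK: `small` broadcasts to (3, 3)
        try:
            small.add_(torch.zeros(3, 3))  # fails: would have to grow `small`
        except RuntimeError:
            pass
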
1199*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
1200*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1201*da0073e9SAndroid Build Coastguard Worker    @wrapDeterministicFlagAPITest
1202*da0073e9SAndroid Build Coastguard Worker    def test_cublas_config_nondeterministic_alert(self, device):
1203*da0073e9SAndroid Build Coastguard Worker        test_cases = [
1204*da0073e9SAndroid Build Coastguard Worker            # (function, (tensor sizes))
1205*da0073e9SAndroid Build Coastguard Worker            ('mm', ((2, 2), (2, 2),)),
1206*da0073e9SAndroid Build Coastguard Worker            ('mv', ((2, 2), (2,),)),
1207*da0073e9SAndroid Build Coastguard Worker            ('bmm', ((1, 2, 2), (1, 2, 2),))]
1208*da0073e9SAndroid Build Coastguard Worker
1209*da0073e9SAndroid Build Coastguard Worker        test_configs = [
1210*da0073e9SAndroid Build Coastguard Worker            # (CuBLAS workspace config, is deterministic)
1211*da0073e9SAndroid Build Coastguard Worker            ('garbage', False),
1212*da0073e9SAndroid Build Coastguard Worker            (None, False),
1213*da0073e9SAndroid Build Coastguard Worker            (':4096:8', True),
1214*da0073e9SAndroid Build Coastguard Worker            (':16:8', True)]
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker        cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
1217*da0073e9SAndroid Build Coastguard Worker        is_cuda10_2_or_higher = (
1218*da0073e9SAndroid Build Coastguard Worker            (torch.version.cuda is not None)
1219*da0073e9SAndroid Build Coastguard Worker            and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker        def test_case_info(fn_name, config):
1222*da0073e9SAndroid Build Coastguard Worker            return f'function "{fn_name}" with config "{"" if config is None else config}"'
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker        # Run a subprocess for each combination of test case and config setting
1226*da0073e9SAndroid Build Coastguard Worker        for fn_name, arg_sizes in test_cases:
1227*da0073e9SAndroid Build Coastguard Worker            for config, is_config_deterministic in test_configs:
1228*da0073e9SAndroid Build Coastguard Worker                env = os.environ.copy()
1229*da0073e9SAndroid Build Coastguard Worker                if config is None:
1230*da0073e9SAndroid Build Coastguard Worker                    if env.get(cublas_var_name) is not None:
1231*da0073e9SAndroid Build Coastguard Worker                        del env[cublas_var_name]
1232*da0073e9SAndroid Build Coastguard Worker                else:
1233*da0073e9SAndroid Build Coastguard Worker                    env[cublas_var_name] = config
1234*da0073e9SAndroid Build Coastguard Worker                should_throw_error = is_cuda10_2_or_higher and not is_config_deterministic
1235*da0073e9SAndroid Build Coastguard Worker                script = f"""
1236*da0073e9SAndroid Build Coastguard Workerimport torch
1237*da0073e9SAndroid Build Coastguard Workertorch.use_deterministic_algorithms(True)
1238*da0073e9SAndroid Build Coastguard Workerfn = torch.{fn_name}
1239*da0073e9SAndroid Build Coastguard Workerarg_sizes = {arg_sizes}
1240*da0073e9SAndroid Build Coastguard Workerdevice = '{device}'
1241*da0073e9SAndroid Build Coastguard Workershould_throw_error = {should_throw_error}
1242*da0073e9SAndroid Build Coastguard Workerargs = []
1243*da0073e9SAndroid Build Coastguard Workerfor arg_size in arg_sizes:
1244*da0073e9SAndroid Build Coastguard Worker    args.append(torch.randn(*arg_size, device=device))
1245*da0073e9SAndroid Build Coastguard Workertry:
1246*da0073e9SAndroid Build Coastguard Worker    fn(*args)
1247*da0073e9SAndroid Build Coastguard Workerexcept RuntimeError as e:
1248*da0073e9SAndroid Build Coastguard Worker    if not should_throw_error:
1249*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError('Did not expect any error to be raised')
1250*da0073e9SAndroid Build Coastguard Worker    elif 'Deterministic behavior was enabled with either' not in str(e):
1251*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError('Expected a CuBLAS nondeterministic error, but got a different error')
1252*da0073e9SAndroid Build Coastguard Workerelse:
1253*da0073e9SAndroid Build Coastguard Worker    if should_throw_error:
1254*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError('Expected a CuBLAS nondeterministic error, but it was not raised')
1255*da0073e9SAndroid Build Coastguard Worker
1256*da0073e9SAndroid Build Coastguard Worker"""
1257*da0073e9SAndroid Build Coastguard Worker                try:
1258*da0073e9SAndroid Build Coastguard Worker                    subprocess.check_output(
1259*da0073e9SAndroid Build Coastguard Worker                        [sys.executable, '-c', script],
1260*da0073e9SAndroid Build Coastguard Worker                        stderr=subprocess.STDOUT,
1261*da0073e9SAndroid Build Coastguard Worker                        # On Windows, opening the subprocess with the default CWD makes `import torch`
1262*da0073e9SAndroid Build Coastguard Worker                        # fail, so just set CWD to this script's directory
1263*da0073e9SAndroid Build Coastguard Worker                        cwd=os.path.dirname(os.path.realpath(__file__)),
1264*da0073e9SAndroid Build Coastguard Worker                        env=env)
1265*da0073e9SAndroid Build Coastguard Worker                except subprocess.CalledProcessError as e:
1266*da0073e9SAndroid Build Coastguard Worker                    self.fail(msg=(
1267*da0073e9SAndroid Build Coastguard Worker                        f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n'
1268*da0073e9SAndroid Build Coastguard Worker                        + e.output.decode("utf-8")))
1269*da0073e9SAndroid Build Coastguard Worker
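    # A minimal sketch (not part of the original suite): it shows the setting
    # the subprocess script above checks for. On CUDA >= 10.2, cuBLAS-backed
    # ops such as torch.mm are only allowed under
    # torch.use_deterministic_algorithms(True) when CUBLAS_WORKSPACE_CONFIG is
    # ':4096:8' or ':16:8'. Setting the variable here, before cuBLAS is first
    # used in the process, is an assumption about typical usage; the test
    # above uses a fresh subprocess to guarantee it.
    def _sketch_cublas_deterministic_config(self):
        if not torch.cuda.is_available():
            return
        os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
        torch.use_deterministic_algorithms(True)
        try:
            a = torch.randn(2, 2, device='cuda')
            torch.mm(a, a)  # no nondeterministic-alert error with this config
        finally:
            torch.use_deterministic_algorithms(False)
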
1270*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1271*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1272*da0073e9SAndroid Build Coastguard Worker    @dtypes(*get_all_qint_dtypes())
1273*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_resize_quantized(self, device, dtype):
1274*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([-1, 0, 1, 2, 3], dtype=torch.float, device=device)
1275*da0073e9SAndroid Build Coastguard Worker        b = torch.quantize_per_tensor(a, 0.1, 10, dtype)
1276*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1277*da0073e9SAndroid Build Coastguard Worker            lambda: b.resize_((10,)),
1278*da0073e9SAndroid Build Coastguard Worker            'quantized_resize_cpu_')
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker    @skipXLA
1281*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1282*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
1283*da0073e9SAndroid Build Coastguard Worker    def test_deterministic_resize(self, device, dtype):
1284*da0073e9SAndroid Build Coastguard Worker        test_cases = [
1285*da0073e9SAndroid Build Coastguard Worker            # size, stride, resize_size
1286*da0073e9SAndroid Build Coastguard Worker            ((10,), (1,), (5,)),
1287*da0073e9SAndroid Build Coastguard Worker            ((10,), (0,), (10,)),
1288*da0073e9SAndroid Build Coastguard Worker            ((10,), (1,), (20,)),
1289*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), None, (2, 3, 4)),
1290*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), None, (6, 3, 4)),
1291*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), None, (2, 5, 4)),
1292*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), None, (2, 3, 6)),
1293*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), None, (3, 4, 5)),
1294*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), (1, 4, 12), (2, 3, 4)),
1295*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), (1, 4, 12), (4, 3, 4)),
1296*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), (1, 4, 12), (2, 4, 4)),
1297*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), (1, 4, 12), (2, 3, 5)),
1298*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), (1, 4, 12), (3, 4, 5)),
1299*da0073e9SAndroid Build Coastguard Worker            ((2, 3, 4), (1, 0, 1), (2, 4, 5)),
1300*da0073e9SAndroid Build Coastguard Worker        ]
1301*da0073e9SAndroid Build Coastguard Worker
1302*da0073e9SAndroid Build Coastguard Worker        for size, stride, resize_size in test_cases:
1303*da0073e9SAndroid Build Coastguard Worker            if stride is None:
1304*da0073e9SAndroid Build Coastguard Worker                a = torch.zeros(size, dtype=dtype, device=device)
1305*da0073e9SAndroid Build Coastguard Worker            else:
1306*da0073e9SAndroid Build Coastguard Worker                a = torch.empty_strided(size, stride, dtype=dtype, device=device).fill_(0)
1307*da0073e9SAndroid Build Coastguard Worker            old_storage = a.untyped_storage().clone()
1308*da0073e9SAndroid Build Coastguard Worker            with DeterministicGuard(True, fill_uninitialized_memory=True):
1309*da0073e9SAndroid Build Coastguard Worker                a.resize_(resize_size)
1310*da0073e9SAndroid Build Coastguard Worker
1311*da0073e9SAndroid Build Coastguard Worker            new_storage = a.untyped_storage()
1312*da0073e9SAndroid Build Coastguard Worker
1313*da0073e9SAndroid Build Coastguard Worker            # If storage size was increased, check that the new section is
1314*da0073e9SAndroid Build Coastguard Worker            # filled with NaN/MAX_INT. Otherwise, check that the storages are
1315*da0073e9SAndroid Build Coastguard Worker            # equal.
1316*da0073e9SAndroid Build Coastguard Worker            old_tensor = torch.tensor(old_storage, dtype=dtype)
1317*da0073e9SAndroid Build Coastguard Worker            old_numel = old_tensor.numel()
1318*da0073e9SAndroid Build Coastguard Worker            new_tensor = torch.tensor(new_storage, dtype=dtype)
1319*da0073e9SAndroid Build Coastguard Worker            new_numel = new_tensor.numel()
1320*da0073e9SAndroid Build Coastguard Worker
1321*da0073e9SAndroid Build Coastguard Worker            if new_numel > old_numel:
1322*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(new_tensor[:old_numel], old_tensor)
1323*da0073e9SAndroid Build Coastguard Worker                fill_section = new_tensor[old_numel:]
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker                if dtype.is_floating_point or dtype.is_complex:
1326*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(fill_section.isnan().all())
1327*da0073e9SAndroid Build Coastguard Worker                else:
1328*da0073e9SAndroid Build Coastguard Worker                    if dtype == torch.bool:
1329*da0073e9SAndroid Build Coastguard Worker                        max_val = True
1330*da0073e9SAndroid Build Coastguard Worker                    else:
1331*da0073e9SAndroid Build Coastguard Worker                        max_val = torch.iinfo(dtype).max
1332*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(fill_section.eq(max_val).all())
1333*da0073e9SAndroid Build Coastguard Worker            else:
1334*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(old_tensor, new_tensor)
1335*da0073e9SAndroid Build Coastguard Worker
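    # A self-contained sketch of the property checked above (illustrative only,
    # not part of the generated device/dtype test matrix; the method name is
    # hypothetical and CPU float32 is chosen arbitrarily): growing a tensor
    # under deterministic mode preserves the old elements and NaN-fills only the
    # newly allocated region.
    def _example_deterministic_resize_fill(self):
        t = torch.zeros(3, dtype=torch.float32)
        with DeterministicGuard(True, fill_uninitialized_memory=True):
            t.resize_(6)
        # Old data is preserved; the grown region reads back as NaN
        self.assertEqual(t[:3], torch.zeros(3))
        self.assertTrue(t[3:].isnan().all())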
1336*da0073e9SAndroid Build Coastguard Worker    # When deterministic algorithms are enabled, `torch.empty` should fill floating
1337*da0073e9SAndroid Build Coastguard Worker    # point tensors with NaN and integer tensors with MAX_INT
1338*da0073e9SAndroid Build Coastguard Worker    @skipXLA
1339*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1340*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
1341*da0073e9SAndroid Build Coastguard Worker    def test_deterministic_empty(self, device, dtype):
1342*da0073e9SAndroid Build Coastguard Worker        gen_fns = [
1343*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty(10, 9, device=device, dtype=dtype),
1344*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty(10, 9, out=torch.zeros(1, device=device, dtype=dtype)),
1345*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype)),
1346*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype), memory_format=torch.contiguous_format),
1347*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty_strided((10, 9), (1, 5), device=device, dtype=dtype),
1348*da0073e9SAndroid Build Coastguard Worker            lambda: torch.empty_permuted((2, 3, 5), (1, 0, 2), device=device, dtype=dtype),
1349*da0073e9SAndroid Build Coastguard Worker        ]
1350*da0073e9SAndroid Build Coastguard Worker
1351*da0073e9SAndroid Build Coastguard Worker        for gen_fn in gen_fns:
1352*da0073e9SAndroid Build Coastguard Worker            with DeterministicGuard(True, fill_uninitialized_memory=True):
1353*da0073e9SAndroid Build Coastguard Worker                res = gen_fn()
1354*da0073e9SAndroid Build Coastguard Worker
1355*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point or dtype.is_complex:
1356*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(res.isnan().all())
1357*da0073e9SAndroid Build Coastguard Worker            else:
1358*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.bool:
1359*da0073e9SAndroid Build Coastguard Worker                    max_val = True
1360*da0073e9SAndroid Build Coastguard Worker                else:
1361*da0073e9SAndroid Build Coastguard Worker                    max_val = torch.iinfo(dtype).max
1362*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(res.eq(max_val).all())
1363*da0073e9SAndroid Build Coastguard Worker
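    # A self-contained sketch of the fill values checked above (illustrative
    # only, not part of the generated device/dtype test matrix; the method name
    # is hypothetical and CPU float32/int32 are chosen arbitrarily).
    def _example_deterministic_empty_fill(self):
        with DeterministicGuard(True, fill_uninitialized_memory=True):
            f = torch.empty(4, dtype=torch.float32)
            i = torch.empty(4, dtype=torch.int32)
        # Floating point memory reads back as NaN, integer memory as MAX_INT
        self.assertTrue(f.isnan().all())
        self.assertTrue(i.eq(torch.iinfo(torch.int32).max).all())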
1364*da0073e9SAndroid Build Coastguard Worker    # FIXME: update OpInfos to support "nondeterministic samples" and port these tests
1365*da0073e9SAndroid Build Coastguard Worker    #   to that architecture
1366*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1367*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1368*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_AvgPool3d(self, device):
1369*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.AvgPool3d(3)
1370*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1371*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1372*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1375*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1376*da0073e9SAndroid Build Coastguard Worker            'avg_pool3d_backward_cuda',
1377*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1378*da0073e9SAndroid Build Coastguard Worker
1379*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1380*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1381*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_AdaptiveAvgPool2d(self, device):
1382*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.AdaptiveAvgPool2d(3)
1383*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 3, requires_grad=True, device=device)
1384*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1385*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1388*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1389*da0073e9SAndroid Build Coastguard Worker            'adaptive_avg_pool2d_backward_cuda',
1390*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1393*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1394*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_AdaptiveAvgPool3d(self, device):
1395*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.AdaptiveAvgPool3d(3)
1396*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1397*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1398*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1399*da0073e9SAndroid Build Coastguard Worker
1400*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1401*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1402*da0073e9SAndroid Build Coastguard Worker            'adaptive_avg_pool3d_backward_cuda',
1403*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1404*da0073e9SAndroid Build Coastguard Worker
1405*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1406*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1407*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_MaxPool3d(self, device):
1408*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.MaxPool3d(3)
1409*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1410*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1411*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1414*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1415*da0073e9SAndroid Build Coastguard Worker            'max_pool3d_with_indices_backward_cuda',
1416*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1417*da0073e9SAndroid Build Coastguard Worker
1418*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1419*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1420*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_AdaptiveMaxPool2d(self, device):
1421*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.AdaptiveMaxPool2d(3)
1422*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 3, requires_grad=True, device=device)
1423*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1424*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1427*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1428*da0073e9SAndroid Build Coastguard Worker            'adaptive_max_pool2d_backward_cuda',
1429*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1432*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1433*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_FractionalMaxPool2d(self, device):
1434*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.FractionalMaxPool2d(2, output_ratio=0.5)
1435*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1436*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1437*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1438*da0073e9SAndroid Build Coastguard Worker
1439*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1440*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1441*da0073e9SAndroid Build Coastguard Worker            'fractional_max_pool2d_backward_cuda',
1442*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1443*da0073e9SAndroid Build Coastguard Worker
1444*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1445*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1446*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_FractionalMaxPool3d(self, device):
1447*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.FractionalMaxPool3d(2, output_ratio=0.5)
1448*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 3, 3, 3, requires_grad=True, device=device)
1449*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1450*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1451*da0073e9SAndroid Build Coastguard Worker
1452*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1453*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1454*da0073e9SAndroid Build Coastguard Worker            'fractional_max_pool3d_backward_cuda',
1455*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half))
1458*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1459*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_MaxUnpool1d(self, device, dtype):
1460*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.half and torch.device(device).type == 'cpu':
1461*da0073e9SAndroid Build Coastguard Worker            self.skipTest('float16 not implemented on CPU')
1462*da0073e9SAndroid Build Coastguard Worker
1463*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.MaxUnpool1d(3, 1)
1464*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 1, 7, dtype=dtype, device=device)
1465*da0073e9SAndroid Build Coastguard Worker        indices = torch.zeros_like(input, dtype=torch.long, device=device)
1466*da0073e9SAndroid Build Coastguard Worker
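        # MaxUnpool1d is implemented in terms of the 2d unpooling kernel, hence
        # the 2d op name passed below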
1467*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1468*da0073e9SAndroid Build Coastguard Worker            lambda: module(input, indices),
1469*da0073e9SAndroid Build Coastguard Worker            'max_unpooling2d_forward_out')
1470*da0073e9SAndroid Build Coastguard Worker
1471*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half))
1472*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1473*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_MaxUnpool2d(self, device, dtype):
1474*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.half and torch.device(device).type == 'cpu':
1475*da0073e9SAndroid Build Coastguard Worker            self.skipTest('float16 not implemented on CPU')
1476*da0073e9SAndroid Build Coastguard Worker
1477*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.MaxUnpool2d(3, 1)
1478*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 1, 7, 7, dtype=dtype, device=device)
1479*da0073e9SAndroid Build Coastguard Worker        indices = torch.zeros_like(input, dtype=torch.long, device=device)
1480*da0073e9SAndroid Build Coastguard Worker
1481*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1482*da0073e9SAndroid Build Coastguard Worker            lambda: module(input, indices),
1483*da0073e9SAndroid Build Coastguard Worker            'max_unpooling2d_forward_out')
1484*da0073e9SAndroid Build Coastguard Worker
1485*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half))
1486*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1487*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_MaxUnpool3d(self, device, dtype):
1488*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.half and torch.device(device).type == 'cpu':
1489*da0073e9SAndroid Build Coastguard Worker            self.skipTest('float16 not implemented on CPU')
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.MaxUnpool3d(3, 1)
1492*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 1, 7, 7, 7, dtype=dtype, device=device)
1493*da0073e9SAndroid Build Coastguard Worker        indices = torch.zeros_like(input, dtype=torch.long, device=device)
1494*da0073e9SAndroid Build Coastguard Worker
1495*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1496*da0073e9SAndroid Build Coastguard Worker            lambda: module(input, indices),
1497*da0073e9SAndroid Build Coastguard Worker            'max_unpooling3d_forward_out')
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1500*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1501*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_interpolate_linear(self, device):
1502*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 2, 4, device=device, requires_grad=True)
1503*da0073e9SAndroid Build Coastguard Worker        res = torch.nn.functional.interpolate(
1504*da0073e9SAndroid Build Coastguard Worker            input,
1505*da0073e9SAndroid Build Coastguard Worker            size=12,
1506*da0073e9SAndroid Build Coastguard Worker            mode='linear',
1507*da0073e9SAndroid Build Coastguard Worker            align_corners=False)
1508*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1509*da0073e9SAndroid Build Coastguard Worker
1510*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1511*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad),
1512*da0073e9SAndroid Build Coastguard Worker            'upsample_linear1d_backward_out_cuda',
1513*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1514*da0073e9SAndroid Build Coastguard Worker
1515*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1516*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_interpolate_bilinear(self, device):
1517*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
1518*da0073e9SAndroid Build Coastguard Worker        res = torch.nn.functional.interpolate(
1519*da0073e9SAndroid Build Coastguard Worker            input,
1520*da0073e9SAndroid Build Coastguard Worker            size=12,
1521*da0073e9SAndroid Build Coastguard Worker            mode='bilinear',
1522*da0073e9SAndroid Build Coastguard Worker            align_corners=False)
1523*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1524*da0073e9SAndroid Build Coastguard Worker
1525*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1526*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad),
1527*da0073e9SAndroid Build Coastguard Worker            'upsample_bilinear2d_backward_out_cuda',
1528*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1529*da0073e9SAndroid Build Coastguard Worker
1530*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("aot-autograd issue")
1531*da0073e9SAndroid Build Coastguard Worker    def test_deterministic_replication_pad2d(self, device):
1532*da0073e9SAndroid Build Coastguard Worker        test_cases = [
1533*da0073e9SAndroid Build Coastguard Worker            # size, padding
1534*da0073e9SAndroid Build Coastguard Worker            [(1, 2, 4, 4), (0, 0, 0, 0)],
1535*da0073e9SAndroid Build Coastguard Worker            [(1, 2, 4, 4), (3, 4, 5, 6)],
1536*da0073e9SAndroid Build Coastguard Worker            [(3, 8, 7), (0, 0, 0, 0)],
1537*da0073e9SAndroid Build Coastguard Worker            [(3, 8, 7), (4, 3, 2, 7)],
1538*da0073e9SAndroid Build Coastguard Worker        ]
1539*da0073e9SAndroid Build Coastguard Worker
1540*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type != 'xla':
1541*da0073e9SAndroid Build Coastguard Worker            test_cases += [
1542*da0073e9SAndroid Build Coastguard Worker                [(4, 3, 5, 10), (-9, 4, 5, 6)],
1543*da0073e9SAndroid Build Coastguard Worker                [(3, 8, 7), (-4, -2, -2, -3)],
1544*da0073e9SAndroid Build Coastguard Worker            ]
1545*da0073e9SAndroid Build Coastguard Worker
1546*da0073e9SAndroid Build Coastguard Worker        for size, padding in test_cases:
1547*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(*size, device=device, requires_grad=True)
1548*da0073e9SAndroid Build Coastguard Worker            grad = None
1549*da0073e9SAndroid Build Coastguard Worker            with DeterministicGuard(True):
                # Repeat the same case so each run's gradient can be compared
                # against the first run (mirrors test_deterministic_interpolate_bilinear;
                # without this loop the comparison branch below is never reached)
                for _ in range(5):
1550*da0073e9SAndroid Build Coastguard Worker                    res = torch.nn.functional.pad(
1551*da0073e9SAndroid Build Coastguard Worker                        input,
1552*da0073e9SAndroid Build Coastguard Worker                        padding,
1553*da0073e9SAndroid Build Coastguard Worker                        mode='replicate')
1554*da0073e9SAndroid Build Coastguard Worker                    res.backward(torch.ones_like(res))
1555*da0073e9SAndroid Build Coastguard Worker                    if grad is None:
1556*da0073e9SAndroid Build Coastguard Worker                        grad = input.grad
1557*da0073e9SAndroid Build Coastguard Worker                    else:
1558*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(grad, input.grad, atol=0, rtol=0)
1559*da0073e9SAndroid Build Coastguard Worker                    input.grad = None
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1562*da0073e9SAndroid Build Coastguard Worker    def test_deterministic_interpolate_bilinear(self, device):
1563*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
1564*da0073e9SAndroid Build Coastguard Worker        grad = None
1565*da0073e9SAndroid Build Coastguard Worker        with DeterministicGuard(True):
1566*da0073e9SAndroid Build Coastguard Worker            for _ in range(5):
1567*da0073e9SAndroid Build Coastguard Worker                res = torch.nn.functional.interpolate(
1568*da0073e9SAndroid Build Coastguard Worker                    input,
1569*da0073e9SAndroid Build Coastguard Worker                    size=12,
1570*da0073e9SAndroid Build Coastguard Worker                    mode='bilinear',
1571*da0073e9SAndroid Build Coastguard Worker                    align_corners=False)
1572*da0073e9SAndroid Build Coastguard Worker                res.backward(torch.ones_like(res))
1573*da0073e9SAndroid Build Coastguard Worker                if grad is None:
1574*da0073e9SAndroid Build Coastguard Worker                    grad = input.grad
1575*da0073e9SAndroid Build Coastguard Worker                else:
1576*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grad, input.grad, atol=0, rtol=0)
1577*da0073e9SAndroid Build Coastguard Worker                input.grad = None
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1580*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1581*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_interpolate_bicubic(self, device):
1582*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
1583*da0073e9SAndroid Build Coastguard Worker        res = torch.nn.functional.interpolate(
1584*da0073e9SAndroid Build Coastguard Worker            input,
1585*da0073e9SAndroid Build Coastguard Worker            size=12,
1586*da0073e9SAndroid Build Coastguard Worker            mode='bicubic',
1587*da0073e9SAndroid Build Coastguard Worker            align_corners=False)
1588*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1589*da0073e9SAndroid Build Coastguard Worker
1590*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1591*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad),
1592*da0073e9SAndroid Build Coastguard Worker            'upsample_bicubic2d_backward_out_cuda',
1593*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1594*da0073e9SAndroid Build Coastguard Worker
1595*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1596*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1597*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_interpolate_trilinear(self, device):
1598*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 2, 4, 4, 4, device=device, requires_grad=True)
1599*da0073e9SAndroid Build Coastguard Worker        res = torch.nn.functional.interpolate(
1600*da0073e9SAndroid Build Coastguard Worker            input,
1601*da0073e9SAndroid Build Coastguard Worker            size=12,
1602*da0073e9SAndroid Build Coastguard Worker            mode='trilinear',
1603*da0073e9SAndroid Build Coastguard Worker            align_corners=False)
1604*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1607*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad),
1608*da0073e9SAndroid Build Coastguard Worker            'upsample_trilinear3d_backward_out_cuda',
1609*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1610*da0073e9SAndroid Build Coastguard Worker
1611*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1612*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1613*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_ReflectionPad1d(self, device):
1614*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.ReflectionPad1d((1, 2))
1615*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 8, device=device, requires_grad=True)
1616*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1617*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1620*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1621*da0073e9SAndroid Build Coastguard Worker            'reflection_pad1d_backward_out_cuda',
1622*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1625*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_ReflectionPad2d(self, device):
1626*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.ReflectionPad2d((1, 2, 3, 4))
1627*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 8, 8, device=device, requires_grad=True)
1628*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1629*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1630*da0073e9SAndroid Build Coastguard Worker
1631*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1632*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1633*da0073e9SAndroid Build Coastguard Worker            'reflection_pad2d_backward_cuda',
1634*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1637*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1638*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_ReflectionPad3d(self, device):
1639*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.ReflectionPad3d((1, 2, 3, 4, 5, 6))
1640*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 8, 8, 8, device=device, requires_grad=True)
1641*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1642*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1643*da0073e9SAndroid Build Coastguard Worker
1644*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1645*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1646*da0073e9SAndroid Build Coastguard Worker            'reflection_pad3d_backward_out_cuda',
1647*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1650*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1651*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_ReplicationPad1d(self, device):
1652*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.ReplicationPad1d((1, 2))
1653*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 4, device=device, requires_grad=True)
1654*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1655*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1656*da0073e9SAndroid Build Coastguard Worker
1657*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1658*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1659*da0073e9SAndroid Build Coastguard Worker            'replication_pad1d_backward_cuda',
1660*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1661*da0073e9SAndroid Build Coastguard Worker
1662*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1663*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_ReplicationPad2d(self, device):
1664*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.ReplicationPad2d((1, 2, 3, 4))
1665*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 4, 4, device=device, requires_grad=True)
1666*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1667*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1668*da0073e9SAndroid Build Coastguard Worker
1669*da0073e9SAndroid Build Coastguard Worker        # Nondeterministic alert should only be raised if the forward call was
1670*da0073e9SAndroid Build Coastguard Worker        # nondeterministic
1671*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1672*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1673*da0073e9SAndroid Build Coastguard Worker            'replication_pad2d_backward_cuda',
1674*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1675*da0073e9SAndroid Build Coastguard Worker
1676*da0073e9SAndroid Build Coastguard Worker        with DeterministicGuard(True):
1677*da0073e9SAndroid Build Coastguard Worker            res = module(input)
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1680*da0073e9SAndroid Build Coastguard Worker
1681*da0073e9SAndroid Build Coastguard Worker        # If the forward call was deterministic, nondeterministic alert should
1682*da0073e9SAndroid Build Coastguard Worker        # not be raised
1683*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1684*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1685*da0073e9SAndroid Build Coastguard Worker            'replication_pad2d_backward_cuda',
1686*da0073e9SAndroid Build Coastguard Worker            False)
1687*da0073e9SAndroid Build Coastguard Worker
1688*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1689*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1690*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_ReplicationPad3d(self, device):
1691*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.ReplicationPad3d((1, 2, 3, 4, 5, 6))
1692*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 4, 4, 4, device=device, requires_grad=True)
1693*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1694*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1695*da0073e9SAndroid Build Coastguard Worker
1696*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1697*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1698*da0073e9SAndroid Build Coastguard Worker            'replication_pad3d_backward_cuda',
1699*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1700*da0073e9SAndroid Build Coastguard Worker
1701*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Warning is not raised.")
1702*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_NLLLoss(self, device):
1703*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.NLLLoss()
1704*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, 5, 5, device=device)
1705*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(2, 5, 5, device=device).mul(3).floor().long()
1706*da0073e9SAndroid Build Coastguard Worker
1708*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1709*da0073e9SAndroid Build Coastguard Worker            lambda: module(input, target),
1710*da0073e9SAndroid Build Coastguard Worker            'nll_loss2d_forward_out_cuda_template',
1711*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1712*da0073e9SAndroid Build Coastguard Worker
1713*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1714*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_CTCLoss(self, device):
1715*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.CTCLoss()
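        # CTCLoss shapes: input is (T=50, N=3, C=15); target is (N=3, S=30),
        # trimmed per sample by target_lengths below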
1716*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(50, 3, 15, device=device, requires_grad=True)
1717*da0073e9SAndroid Build Coastguard Worker        target = torch.randint(0, 14, (3, 30), device=device)
1718*da0073e9SAndroid Build Coastguard Worker        input_lengths = [50, 50, 50]
1719*da0073e9SAndroid Build Coastguard Worker        target_lengths = [30, 25, 20]
1720*da0073e9SAndroid Build Coastguard Worker        res = module(input, target, input_lengths, target_lengths)
1721*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1722*da0073e9SAndroid Build Coastguard Worker
1723*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1724*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1725*da0073e9SAndroid Build Coastguard Worker            'ctc_loss_backward_gpu',
1726*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1727*da0073e9SAndroid Build Coastguard Worker
1728*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1729*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_EmbeddingBag_max(self, device):
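        # Positional args below: num_embeddings=4, embedding_dim=3, max_norm=None,
        # norm_type=2., scale_grad_by_freq=False, mode='max'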
1730*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.EmbeddingBag(
1731*da0073e9SAndroid Build Coastguard Worker            4, 3, None, 2., False, 'max',
1732*da0073e9SAndroid Build Coastguard Worker            _weight=torch.randn(4, 3, device=device, requires_grad=True))
1733*da0073e9SAndroid Build Coastguard Worker        input = torch.randint(0, 3, (4, 3), device=device)
1734*da0073e9SAndroid Build Coastguard Worker        res = module(input)
1735*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1736*da0073e9SAndroid Build Coastguard Worker
1737*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1738*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1739*da0073e9SAndroid Build Coastguard Worker            'embedding_bag_backward_cuda_max',
1740*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1741*da0073e9SAndroid Build Coastguard Worker
1742*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bool))
1743*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1744*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_cumsum(self, device, dtype):
1745*da0073e9SAndroid Build Coastguard Worker        input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
1746*da0073e9SAndroid Build Coastguard Worker        should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex)
1747*da0073e9SAndroid Build Coastguard Worker
1748*da0073e9SAndroid Build Coastguard Worker        for op_call in [torch.Tensor.cumsum, torch.cumsum]:
1749*da0073e9SAndroid Build Coastguard Worker            self.check_nondeterministic_alert(
1750*da0073e9SAndroid Build Coastguard Worker                lambda: op_call(input, 0),
1751*da0073e9SAndroid Build Coastguard Worker                'cumsum_cuda_kernel',
1752*da0073e9SAndroid Build Coastguard Worker                should_alert)
1753*da0073e9SAndroid Build Coastguard Worker
1754*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # expected a non-deterministic error, but it was not raised
1755*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1756*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_put(self, device):
1757*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, device=device)
1758*da0073e9SAndroid Build Coastguard Worker        indices = torch.tensor([0, 0], device=device)
1759*da0073e9SAndroid Build Coastguard Worker        values = torch.tensor([0., 1.], device=device)
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker        for op_call in [torch.Tensor.put, torch.Tensor.put_]:
1762*da0073e9SAndroid Build Coastguard Worker            self.check_nondeterministic_alert(
1763*da0073e9SAndroid Build Coastguard Worker                lambda: op_call(a, indices, values, accumulate=False),
1764*da0073e9SAndroid Build Coastguard Worker                'put_')
1765*da0073e9SAndroid Build Coastguard Worker
1766*da0073e9SAndroid Build Coastguard Worker    # warn_only=False correctly raises RuntimeError: put_ does not have a deterministic implementation
1767*da0073e9SAndroid Build Coastguard Worker    # warn_only=True logs the warning from the FallbackKernel (torch.ops.aten.put_.default) instead of as a UserWarning:
1768*da0073e9SAndroid Build Coastguard Worker    # [W Context.cpp:%(lineno)] Warning: put_ does not have a deterministic implementation
    # (an illustrative sketch of the warn_only distinction follows this test)
1769*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("warning is logged from the FallbackKernel: torch.ops.aten.put_.default when warn_only=True")
1770*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_put_accumulate(self, device):
1771*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, device=device)
1772*da0073e9SAndroid Build Coastguard Worker        indices = torch.tensor([0, 0], device=device)
1773*da0073e9SAndroid Build Coastguard Worker        values = torch.tensor([0., 1.], device=device)
1774*da0073e9SAndroid Build Coastguard Worker
1775*da0073e9SAndroid Build Coastguard Worker        for op_call in [torch.Tensor.put, torch.Tensor.put_]:
1776*da0073e9SAndroid Build Coastguard Worker            self.check_nondeterministic_alert(
1777*da0073e9SAndroid Build Coastguard Worker                lambda: op_call(a, indices, values, accumulate=True),
1778*da0073e9SAndroid Build Coastguard Worker                'put_',
1779*da0073e9SAndroid Build Coastguard Worker                torch.device(device).type == 'cuda')
1780*da0073e9SAndroid Build Coastguard Worker
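    # A minimal sketch of the warn_only distinction described in the comment
    # above this test (illustrative only: the method name is hypothetical, it
    # assumes a CUDA device is available, and it assumes DeterministicGuard
    # forwards warn_only to torch.use_deterministic_algorithms):
    def _example_put_accumulate_warn_only(self):
        a = torch.randn(10, device='cuda')
        indices = torch.tensor([0, 0], device='cuda')
        values = torch.tensor([0., 1.], device='cuda')
        # warn_only=False: the nondeterministic op raises a RuntimeError
        with DeterministicGuard(True, warn_only=False):
            with self.assertRaisesRegex(RuntimeError, 'put_'):
                a.put_(indices, values, accumulate=True)
        # warn_only=True: the same call only emits a UserWarning
        with DeterministicGuard(True, warn_only=True):
            with self.assertWarnsRegex(UserWarning, 'put_'):
                a.put_(indices, values, accumulate=True)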
1781*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1782*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_histc(self, device):
1783*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([], device=device)
1784*da0073e9SAndroid Build Coastguard Worker        for op_call in [torch.histc, torch.Tensor.histc]:
1785*da0073e9SAndroid Build Coastguard Worker            self.check_nondeterministic_alert(
1786*da0073e9SAndroid Build Coastguard Worker                lambda: op_call(a, min=0, max=3),
1787*da0073e9SAndroid Build Coastguard Worker                '_histc_cuda',
1788*da0073e9SAndroid Build Coastguard Worker                torch.device(device).type == 'cuda')
1789*da0073e9SAndroid Build Coastguard Worker
1790*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1791*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_bincount(self, device):
1792*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([], device=device, dtype=torch.long)
1793*da0073e9SAndroid Build Coastguard Worker        weights = torch.tensor([], device=device)
1794*da0073e9SAndroid Build Coastguard Worker
1795*da0073e9SAndroid Build Coastguard Worker        for op_call in [torch.bincount, torch.Tensor.bincount]:
1796*da0073e9SAndroid Build Coastguard Worker            # Error should only be raised when device is CUDA and weights are
1797*da0073e9SAndroid Build Coastguard Worker            # given
1798*da0073e9SAndroid Build Coastguard Worker            self.check_nondeterministic_alert(
1799*da0073e9SAndroid Build Coastguard Worker                lambda: op_call(a, weights),
1800*da0073e9SAndroid Build Coastguard Worker                '_bincount_cuda',
1801*da0073e9SAndroid Build Coastguard Worker                torch.device(device).type == 'cuda')
1802*da0073e9SAndroid Build Coastguard Worker
1803*da0073e9SAndroid Build Coastguard Worker            self.check_nondeterministic_alert(
1804*da0073e9SAndroid Build Coastguard Worker                lambda: op_call(a),
1805*da0073e9SAndroid Build Coastguard Worker                '_bincount_cuda',
1806*da0073e9SAndroid Build Coastguard Worker                False)
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker    # Ensures that kthvalue throws nondeterministic alerts in the correct cases
1809*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1810*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_kthvalue(self, device, dtype):
1811*da0073e9SAndroid Build Coastguard Worker        def test_func(call_type):
1812*da0073e9SAndroid Build Coastguard Worker            S = 10
1813*da0073e9SAndroid Build Coastguard Worker            k = 5
1814*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(S, device=device)
1815*da0073e9SAndroid Build Coastguard Worker            if call_type == 'function':
1816*da0073e9SAndroid Build Coastguard Worker                torch.kthvalue(a, k)
1817*da0073e9SAndroid Build Coastguard Worker            elif call_type == 'method':
1818*da0073e9SAndroid Build Coastguard Worker                a.kthvalue(k)
1819*da0073e9SAndroid Build Coastguard Worker            elif call_type == 'out':
1820*da0073e9SAndroid Build Coastguard Worker                values = torch.empty_like(a)
1821*da0073e9SAndroid Build Coastguard Worker                indices = torch.empty((), device=device, dtype=torch.long)
1822*da0073e9SAndroid Build Coastguard Worker                torch.kthvalue(a, k, out=(values, indices))
1823*da0073e9SAndroid Build Coastguard Worker            else:
1824*da0073e9SAndroid Build Coastguard Worker                self.fail(f"'{call_type}' is not a valid call type")
1825*da0073e9SAndroid Build Coastguard Worker
1826*da0073e9SAndroid Build Coastguard Worker        for call_type in ['function', 'method', 'out']:
1827*da0073e9SAndroid Build Coastguard Worker            self.check_nondeterministic_alert(
1828*da0073e9SAndroid Build Coastguard Worker                lambda: test_func(call_type),
1829*da0073e9SAndroid Build Coastguard Worker                'kthvalue CUDA',
1830*da0073e9SAndroid Build Coastguard Worker                torch.device(device).type == 'cuda')
1831*da0073e9SAndroid Build Coastguard Worker
1832*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1833*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1834*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_grid_sample_2d(self, device):
1835*da0073e9SAndroid Build Coastguard Worker        input = torch.empty(1, 1, 2, 2, device=device, requires_grad=True)
1836*da0073e9SAndroid Build Coastguard Worker        grid = torch.empty(1, 1, 1, 2, device=device)
1837*da0073e9SAndroid Build Coastguard Worker        res = torch.nn.functional.grid_sample(input, grid, align_corners=False)
1838*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1839*da0073e9SAndroid Build Coastguard Worker
1840*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1841*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1842*da0073e9SAndroid Build Coastguard Worker            'grid_sampler_2d_backward_cuda',
1843*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1844*da0073e9SAndroid Build Coastguard Worker
1845*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
1846*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
1847*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_grid_sample_3d(self, device):
1848*da0073e9SAndroid Build Coastguard Worker        input = torch.empty(1, 1, 2, 2, 2, device=device, requires_grad=True)
1849*da0073e9SAndroid Build Coastguard Worker        grid = torch.empty(1, 1, 1, 2, 3, device=device)
1850*da0073e9SAndroid Build Coastguard Worker        res = torch.nn.functional.grid_sample(input, grid, align_corners=False)
1851*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(res)
1852*da0073e9SAndroid Build Coastguard Worker
1853*da0073e9SAndroid Build Coastguard Worker        self.check_nondeterministic_alert(
1854*da0073e9SAndroid Build Coastguard Worker            lambda: res.backward(grad, retain_graph=True),
1855*da0073e9SAndroid Build Coastguard Worker            'grid_sampler_3d_backward_cuda',
1856*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type == 'cuda')
1857*da0073e9SAndroid Build Coastguard Worker
1858*da0073e9SAndroid Build Coastguard Worker    def test_invalid_shapes_grid_sampler(self, device):
1859*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(
1860*da0073e9SAndroid Build Coastguard Worker            make_tensor, device=device, dtype=torch.float64, requires_grad=True)
1861*da0073e9SAndroid Build Coastguard Worker
1862*da0073e9SAndroid Build Coastguard Worker        inputs = (
1863*da0073e9SAndroid Build Coastguard Worker            # input shape, grid shape (batch sizes deliberately differ: 5 vs. 1)
1864*da0073e9SAndroid Build Coastguard Worker            ((5, 5, 5, 5, 5,), (1, 1, 1, 4, 4,)),  # 3d
1865*da0073e9SAndroid Build Coastguard Worker            ((5, 5, 5, 5,), (1, 1, 4, 4,)),  # 2d
1866*da0073e9SAndroid Build Coastguard Worker        )
1867*da0073e9SAndroid Build Coastguard Worker
1868*da0073e9SAndroid Build Coastguard Worker        interpolation_mode = 0
1869*da0073e9SAndroid Build Coastguard Worker        padding_mode = 0
1870*da0073e9SAndroid Build Coastguard Worker        align_corners = True
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker        err = "expected grid and input to have same batch size"
1873*da0073e9SAndroid Build Coastguard Worker
1874*da0073e9SAndroid Build Coastguard Worker        for input, grid in inputs:
1875*da0073e9SAndroid Build Coastguard Worker            input = make_arg(input)
1876*da0073e9SAndroid Build Coastguard Worker            grid = make_arg(grid, low=-1, high=1)
1877*da0073e9SAndroid Build Coastguard Worker
1878*da0073e9SAndroid Build Coastguard Worker            # Wrapper for the 2d, 3d, and cuDNN functions listed below.
1879*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, err):
1880*da0073e9SAndroid Build Coastguard Worker                torch.grid_sampler(
1881*da0073e9SAndroid Build Coastguard Worker                    input, grid, interpolation_mode, padding_mode,
1882*da0073e9SAndroid Build Coastguard Worker                    align_corners)
1883*da0073e9SAndroid Build Coastguard Worker
1884*da0073e9SAndroid Build Coastguard Worker            # Expects 2d input.
1885*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, err):
1886*da0073e9SAndroid Build Coastguard Worker                torch.grid_sampler_2d(
1887*da0073e9SAndroid Build Coastguard Worker                    input, grid, interpolation_mode, padding_mode,
1888*da0073e9SAndroid Build Coastguard Worker                    align_corners)
1889*da0073e9SAndroid Build Coastguard Worker
1890*da0073e9SAndroid Build Coastguard Worker            # Expects 3d input.
1891*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, err):
1892*da0073e9SAndroid Build Coastguard Worker                torch.grid_sampler_3d(
1893*da0073e9SAndroid Build Coastguard Worker                    input, grid, interpolation_mode, padding_mode,
1894*da0073e9SAndroid Build Coastguard Worker                    align_corners)
1895*da0073e9SAndroid Build Coastguard Worker
1896*da0073e9SAndroid Build Coastguard Worker            # Expects 2d input.
1897*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, err):
1898*da0073e9SAndroid Build Coastguard Worker                torch._grid_sampler_2d_cpu_fallback(
1899*da0073e9SAndroid Build Coastguard Worker                    input, grid, interpolation_mode, padding_mode,
1900*da0073e9SAndroid Build Coastguard Worker                    align_corners)
1901*da0073e9SAndroid Build Coastguard Worker
1902*da0073e9SAndroid Build Coastguard Worker            # Expects 2d input, on CUDA.
1903*da0073e9SAndroid Build Coastguard Worker            # Doesn't work on CPU and ROCm.
1904*da0073e9SAndroid Build Coastguard Worker            if device != 'cpu' and TEST_CUDNN and not TEST_WITH_ROCM:
1905*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, err):
1906*da0073e9SAndroid Build Coastguard Worker                    torch.cudnn_grid_sampler(input, grid)
1907*da0073e9SAndroid Build Coastguard Worker
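    # torch.dist(x, y, p) is the p-norm of the elementwise difference, so for
    # every p it should match torch.norm(x - y, p); the test below checks that
    # identity, including the p = 0 and p = +/-inf edge cases.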
1908*da0073e9SAndroid Build Coastguard Worker    def test_dist(self, device):
1909*da0073e9SAndroid Build Coastguard Worker        def run_test(x, y):
1910*da0073e9SAndroid Build Coastguard Worker            for p in [0, 1, 2, 3, 4, inf, -inf]:
1911*da0073e9SAndroid Build Coastguard Worker                dist_xy = torch.dist(x, y, p)
1912*da0073e9SAndroid Build Coastguard Worker                dist_xy_norm = torch.norm(x - y, p)
1913*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dist_xy, dist_xy_norm)
1914*da0073e9SAndroid Build Coastguard Worker
1915*da0073e9SAndroid Build Coastguard Worker        run_test(torch.randn(5, device=device), torch.randn(5, device=device))
1916*da0073e9SAndroid Build Coastguard Worker
1917*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(3, device=device)
1918*da0073e9SAndroid Build Coastguard Worker        y = torch.zeros(3, device=device)
1919*da0073e9SAndroid Build Coastguard Worker        y[1] = 1.
1920*da0073e9SAndroid Build Coastguard Worker        run_test(x, y)
1921*da0073e9SAndroid Build Coastguard Worker
1922*da0073e9SAndroid Build Coastguard Worker    # Ensures that median throws nondeterministic alerts in the correct cases
1923*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1924*da0073e9SAndroid Build Coastguard Worker    def test_nondeterministic_alert_median(self, device, dtype):
1925*da0073e9SAndroid Build Coastguard Worker        def test_func(call_type):
1926*da0073e9SAndroid Build Coastguard Worker            S = 10
1927*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(S, device=device)
1928*da0073e9SAndroid Build Coastguard Worker            if call_type == 'function':
1929*da0073e9SAndroid Build Coastguard Worker                torch.median(a)
1930*da0073e9SAndroid Build Coastguard Worker            elif call_type == 'function with indices':
1931*da0073e9SAndroid Build Coastguard Worker                torch.median(a, 0)
1932*da0073e9SAndroid Build Coastguard Worker            elif call_type == 'method':
1933*da0073e9SAndroid Build Coastguard Worker                a.median()
1934*da0073e9SAndroid Build Coastguard Worker            elif call_type == 'method with indices':
1935*da0073e9SAndroid Build Coastguard Worker                a.median(0)
1936*da0073e9SAndroid Build Coastguard Worker            elif call_type == 'out with indices':
1937*da0073e9SAndroid Build Coastguard Worker                result = torch.empty_like(a)
1938*da0073e9SAndroid Build Coastguard Worker                indices = torch.empty((), dtype=torch.long, device=device)
1939*da0073e9SAndroid Build Coastguard Worker                torch.median(a, 0, out=(result, indices))
1940*da0073e9SAndroid Build Coastguard Worker            else:
1941*da0073e9SAndroid Build Coastguard Worker                self.fail(f"'{call_type}' is not a valid call type")
1942*da0073e9SAndroid Build Coastguard Worker
1943*da0073e9SAndroid Build Coastguard Worker        def test_func_expect_error(call_type, should_error):
1944*da0073e9SAndroid Build Coastguard Worker            self.check_nondeterministic_alert(
1945*da0073e9SAndroid Build Coastguard Worker                lambda: test_func(call_type),
1946*da0073e9SAndroid Build Coastguard Worker                'median CUDA with indices output',
1947*da0073e9SAndroid Build Coastguard Worker                should_error)
1948*da0073e9SAndroid Build Coastguard Worker
1949*da0073e9SAndroid Build Coastguard Worker        is_cuda = torch.device(device).type == 'cuda'
1950*da0073e9SAndroid Build Coastguard Worker
1951*da0073e9SAndroid Build Coastguard Worker        test_func_expect_error('function', False)
1952*da0073e9SAndroid Build Coastguard Worker        test_func_expect_error('function with indices', is_cuda)
1953*da0073e9SAndroid Build Coastguard Worker        test_func_expect_error('method', False)
1954*da0073e9SAndroid Build Coastguard Worker        test_func_expect_error('method with indices', is_cuda)
1955*da0073e9SAndroid Build Coastguard Worker        test_func_expect_error('out with indices', is_cuda)
1956*da0073e9SAndroid Build Coastguard Worker
1957*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_scatter_gather_ops
1958*da0073e9SAndroid Build Coastguard Worker    def _test_gather_backward_one_dim(self, device, deterministic: bool = False) -> None:
1959*da0073e9SAndroid Build Coastguard Worker        with DeterministicGuard(deterministic):
1960*da0073e9SAndroid Build Coastguard Worker            m = random.randint(2000, 3000)
1961*da0073e9SAndroid Build Coastguard Worker            elems = random.randint(10 * m, 20 * m)
1962*da0073e9SAndroid Build Coastguard Worker            dim = 0
1963*da0073e9SAndroid Build Coastguard Worker            src = torch.randn(m, device=device, requires_grad=True)
1964*da0073e9SAndroid Build Coastguard Worker            idx = torch.randint(m, (elems,), device=device)
1965*da0073e9SAndroid Build Coastguard Worker            res = torch.gather(src, dim, idx)
1966*da0073e9SAndroid Build Coastguard Worker            weight = torch.rand_like(res, device=device) * 10 ** 6
1967*da0073e9SAndroid Build Coastguard Worker            res.backward(weight)
1968*da0073e9SAndroid Build Coastguard Worker            assert src.grad is not None
1969*da0073e9SAndroid Build Coastguard Worker            grad = src.grad.detach().clone()
1970*da0073e9SAndroid Build Coastguard Worker
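            # gather's backward is essentially a scatter-add of the incoming gradient at idx,
            # which can be nondeterministic when implemented with atomic adds. On CUDA the
            # backward pass is rerun and the gradients must match bit-for-bit; on CPU the
            # gradient is checked against an explicit scatter-add of the weights.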
1971*da0073e9SAndroid Build Coastguard Worker            if torch.device(device).type == 'cuda':
1972*da0073e9SAndroid Build Coastguard Worker                for _ in range(2):
1973*da0073e9SAndroid Build Coastguard Worker                    src.grad.data.zero_()
1974*da0073e9SAndroid Build Coastguard Worker                    res = torch.gather(src, dim, idx)
1975*da0073e9SAndroid Build Coastguard Worker                    res.backward(weight)
1976*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(src.grad, grad, atol=0, rtol=0)
1977*da0073e9SAndroid Build Coastguard Worker            else:
1978*da0073e9SAndroid Build Coastguard Worker                expected = torch.zeros_like(src, device=device)
1979*da0073e9SAndroid Build Coastguard Worker                for i in range(elems):
1980*da0073e9SAndroid Build Coastguard Worker                    expected[idx[i]] += weight[i]
1981*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grad, expected, atol=0, rtol=0)
1982*da0073e9SAndroid Build Coastguard Worker
1983*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_scatter_gather_ops
1984*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1985*da0073e9SAndroid Build Coastguard Worker    def test_gather_backward_deterministic_path(self, device) -> None:
1986*da0073e9SAndroid Build Coastguard Worker        self._test_gather_backward_one_dim(device, True)
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_scatter_gather_ops
1989*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1990*da0073e9SAndroid Build Coastguard Worker    def test_gather_backward_one_dim(self, device) -> None:
1991*da0073e9SAndroid Build Coastguard Worker        self._test_gather_backward_one_dim(device, False)
1992*da0073e9SAndroid Build Coastguard Worker
1993*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_scatter_gather_ops
1994*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1995*da0073e9SAndroid Build Coastguard Worker    def test_scatter_add_one_dim_deterministic(self, device) -> None:
1996*da0073e9SAndroid Build Coastguard Worker        with DeterministicGuard(True):
1997*da0073e9SAndroid Build Coastguard Worker            m = random.randint(20, 30)
1998*da0073e9SAndroid Build Coastguard Worker            elems = random.randint(2000 * m, 3000 * m)
1999*da0073e9SAndroid Build Coastguard Worker            dim = 0
2000*da0073e9SAndroid Build Coastguard Worker            src = torch.randn(elems, device=device)
2001*da0073e9SAndroid Build Coastguard Worker            idx = torch.randint(m, (elems,), device=device)
2002*da0073e9SAndroid Build Coastguard Worker
2003*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros(m, device=device)
2004*da0073e9SAndroid Build Coastguard Worker            res = x.scatter_add(dim, idx, src)
2005*da0073e9SAndroid Build Coastguard Worker
2006*da0073e9SAndroid Build Coastguard Worker            # Checking if scatter_add is deterministic
2007*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
2008*da0073e9SAndroid Build Coastguard Worker                res_next = x.scatter_add(dim, idx, src)
2009*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, res_next, atol=0, rtol=0)
2010*da0073e9SAndroid Build Coastguard Worker                res = res_next
2011*da0073e9SAndroid Build Coastguard Worker
2012*da0073e9SAndroid Build Coastguard Worker            expected = torch.zeros(m, device=device)
2013*da0073e9SAndroid Build Coastguard Worker            for i in range(elems):
2014*da0073e9SAndroid Build Coastguard Worker                expected[idx[i]] += src[i]
2015*da0073e9SAndroid Build Coastguard Worker
2016*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, expected, atol=1e-4, rtol=1e-5)
2017*da0073e9SAndroid Build Coastguard Worker
2018*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_scatter_gather_ops
2019*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2020*da0073e9SAndroid Build Coastguard Worker    def test_scatter_zero_size_index(self, device) -> None:
2021*da0073e9SAndroid Build Coastguard Worker        null_index = torch.zeros((0, 4), dtype=torch.int64)
2022*da0073e9SAndroid Build Coastguard Worker        null_arr = torch.zeros((0, 4))
2023*da0073e9SAndroid Build Coastguard Worker        original = torch.arange(4, dtype=torch.float32)
2024*da0073e9SAndroid Build Coastguard Worker        result = original.scatter(0, null_index, null_arr)
2025*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, original, atol=0, rtol=0)
2026*da0073e9SAndroid Build Coastguard Worker
2027*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
2028*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("FIXME")
2029*da0073e9SAndroid Build Coastguard Worker    def test_sync_warning(self, device):
2030*da0073e9SAndroid Build Coastguard Worker
2031*da0073e9SAndroid Build Coastguard Worker        def _sync_raises_helper(f, level):
2032*da0073e9SAndroid Build Coastguard Worker            with CudaSyncGuard(level):
2033*da0073e9SAndroid Build Coastguard Worker                if level == 1:
2034*da0073e9SAndroid Build Coastguard Worker                    with self.assertWarnsRegex(UserWarning, "called a synchronizing "):
2035*da0073e9SAndroid Build Coastguard Worker                        f()
2036*da0073e9SAndroid Build Coastguard Worker                elif level == 2:
2037*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "called a synchronizing "):
2038*da0073e9SAndroid Build Coastguard Worker                        f()
2039*da0073e9SAndroid Build Coastguard Worker
2040*da0073e9SAndroid Build Coastguard Worker        def _no_sync_helper(f, level):
2041*da0073e9SAndroid Build Coastguard Worker            with CudaSyncGuard(level):
2042*da0073e9SAndroid Build Coastguard Worker                f()
2043*da0073e9SAndroid Build Coastguard Worker
2044*da0073e9SAndroid Build Coastguard Worker        def _ind_put_fn(x, ind, val):
2045*da0073e9SAndroid Build Coastguard Worker            x[ind] = val
2046*da0073e9SAndroid Build Coastguard Worker            return x
2047*da0073e9SAndroid Build Coastguard Worker
2048*da0073e9SAndroid Build Coastguard Worker        def _ind_get_fn(x, ind):
2049*da0073e9SAndroid Build Coastguard Worker            return x[ind]
2050*da0073e9SAndroid Build Coastguard Worker
2051*da0073e9SAndroid Build Coastguard Worker        def _cond_fn(x):
2052*da0073e9SAndroid Build Coastguard Worker            if x:  # taking boolean value of a tensor synchronizes
2053*da0073e9SAndroid Build Coastguard Worker                return x
2054*da0073e9SAndroid Build Coastguard Worker            else:
2055*da0073e9SAndroid Build Coastguard Worker                return 2 * x
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Worker        # prepare inputs for subsequent ops
2058*da0073e9SAndroid Build Coastguard Worker        size = 4
2059*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(size, device=device)
2060*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((), device=device)
2061*da0073e9SAndroid Build Coastguard Worker        ind = torch.randint(size, (3,), device=device)
2062*da0073e9SAndroid Build Coastguard Worker        ind_cpu = ind.cpu()
2063*da0073e9SAndroid Build Coastguard Worker        repeats = torch.full((1,), 2, device=device)
2064*da0073e9SAndroid Build Coastguard Worker        mask = torch.randint(2, (size,), device=device, dtype=bool)
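        # Roughly: the ops below can size their outputs without reading device data (device-side
        # indices, an explicit output_size, reductions that return device tensors), so no host
        # sync is needed, while the expect_sync ops must copy data to the host first (CPU indices,
        # boolean masks, nonzero(), or taking the Python boolean value of a tensor).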
2065*da0073e9SAndroid Build Coastguard Worker        expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.),
2066*da0073e9SAndroid Build Coastguard Worker                          lambda: _ind_put_fn(x, ind, y),
2067*da0073e9SAndroid Build Coastguard Worker                          lambda: _ind_get_fn(x, ind),
2068*da0073e9SAndroid Build Coastguard Worker                          lambda: torch.nn.functional.one_hot(ind, num_classes=size),
2069*da0073e9SAndroid Build Coastguard Worker                          lambda: torch.randperm(20000, device=device),
2070*da0073e9SAndroid Build Coastguard Worker                          lambda: torch.repeat_interleave(x, 2, output_size=2 * size),
2071*da0073e9SAndroid Build Coastguard Worker                          lambda: torch.repeat_interleave(x, repeats, output_size=2 * size),
2072*da0073e9SAndroid Build Coastguard Worker                          lambda: torch.any(y))
2073*da0073e9SAndroid Build Coastguard Worker        expect_sync = (lambda: _ind_put_fn(x, mask, y),
2074*da0073e9SAndroid Build Coastguard Worker                       lambda: _ind_put_fn(x, ind_cpu, y),
2075*da0073e9SAndroid Build Coastguard Worker                       lambda: _ind_get_fn(x, mask),
2076*da0073e9SAndroid Build Coastguard Worker                       lambda: _ind_get_fn(x, ind_cpu),
2077*da0073e9SAndroid Build Coastguard Worker                       lambda: x.nonzero(),
2078*da0073e9SAndroid Build Coastguard Worker                       lambda: _cond_fn(y),
2079*da0073e9SAndroid Build Coastguard Worker                       lambda: torch.nn.functional.one_hot(ind),
2080*da0073e9SAndroid Build Coastguard Worker                       lambda: torch.repeat_interleave(x, repeats))
2081*da0073e9SAndroid Build Coastguard Worker        for f, level in product(expect_no_sync, (1, 2)):
2082*da0073e9SAndroid Build Coastguard Worker            _no_sync_helper(f, level)
2083*da0073e9SAndroid Build Coastguard Worker        for f, level in product(expect_sync, (1, 2)):
2084*da0073e9SAndroid Build Coastguard Worker            _sync_raises_helper(f, level)
2085*da0073e9SAndroid Build Coastguard Worker
2086*da0073e9SAndroid Build Coastguard Worker
2087*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2088*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2089*da0073e9SAndroid Build Coastguard Worker    def test_log_normal(self, device, dtype):
2090*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([10], dtype=dtype, device=device).log_normal_()
2091*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.dtype, dtype)
2092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(), torch.Size([1]))
2093*da0073e9SAndroid Build Coastguard Worker
2094*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half, torch.bfloat16))
2095*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2096*da0073e9SAndroid Build Coastguard Worker    def test_geometric(self, device, dtype):
2097*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([10], dtype=dtype, device=device).geometric_(0.5)
2098*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.dtype, dtype)
2099*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(), torch.Size([1]))
2100*da0073e9SAndroid Build Coastguard Worker
2101*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2102*da0073e9SAndroid Build Coastguard Worker    def test_repeat_interleave(self, device):
2103*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([[1, 2], [3, 4]], device=device)
2104*da0073e9SAndroid Build Coastguard Worker        # exercise single argument function signature
2105*da0073e9SAndroid Build Coastguard Worker        temp = y.repeat_interleave(2)
2106*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([8]), temp.size())
2107*da0073e9SAndroid Build Coastguard Worker
2108*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.int, torch.long]:
2109*da0073e9SAndroid Build Coastguard Worker            lengths = torch.tensor([1, 2], dtype=dtype, device=device)
2110*da0073e9SAndroid Build Coastguard Worker            output_size = torch.sum(lengths)
2111*da0073e9SAndroid Build Coastguard Worker            a = torch.repeat_interleave(
2112*da0073e9SAndroid Build Coastguard Worker                y,
2113*da0073e9SAndroid Build Coastguard Worker                lengths,
2114*da0073e9SAndroid Build Coastguard Worker                dim=0,
2115*da0073e9SAndroid Build Coastguard Worker            )
2116*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.dtype, y.dtype)
2117*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.size(), torch.Size([3, 2]))
2118*da0073e9SAndroid Build Coastguard Worker
2119*da0073e9SAndroid Build Coastguard Worker            a_with_output = torch.repeat_interleave(
2120*da0073e9SAndroid Build Coastguard Worker                y,
2121*da0073e9SAndroid Build Coastguard Worker                lengths,
2122*da0073e9SAndroid Build Coastguard Worker                dim=0,
2123*da0073e9SAndroid Build Coastguard Worker                output_size=output_size,
2124*da0073e9SAndroid Build Coastguard Worker            )
2125*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_with_output.dtype, y.dtype)
2126*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
2127*da0073e9SAndroid Build Coastguard Worker
2128*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types())
2129*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(*floating_types_and(torch.bfloat16, torch.half))
2130*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(torch.half))
2131*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli_p(self, device, dtype):
2132*da0073e9SAndroid Build Coastguard Worker        for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]):
2133*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor(trivial_p, dtype=dtype, device=device)
2134*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.bernoulli().tolist(), trivial_p)
2135*da0073e9SAndroid Build Coastguard Worker
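        # isBinary(t) is True only if every element of t is exactly 0 or 1:
        # ne(t, 0) * ne(t, 1) is nonzero precisely for elements that are neither value.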
2136*da0073e9SAndroid Build Coastguard Worker        def isBinary(t):
2137*da0073e9SAndroid Build Coastguard Worker            return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0
2138*da0073e9SAndroid Build Coastguard Worker
2139*da0073e9SAndroid Build Coastguard Worker        p = torch.rand(5, 5, dtype=dtype, device=device)
2140*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isBinary(p.bernoulli()))
2141*da0073e9SAndroid Build Coastguard Worker
2142*da0073e9SAndroid Build Coastguard Worker        p = torch.rand(5, dtype=dtype, device=device).expand(5, 5)
2143*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isBinary(p.bernoulli()))
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker        p = torch.rand(5, 5, dtype=dtype, device=device)
2146*da0073e9SAndroid Build Coastguard Worker        torch.bernoulli(torch.rand_like(p), out=p)
2147*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isBinary(p))
2148*da0073e9SAndroid Build Coastguard Worker
2149*da0073e9SAndroid Build Coastguard Worker    # RngUniform is not implemented for integral types in the XLA test
2150*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types())
2151*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(*all_types_and(torch.bool, torch.half))
2152*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*all_types_and(torch.bool, torch.half))
2153*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli_self(self, device, dtype):
2154*da0073e9SAndroid Build Coastguard Worker
2155*da0073e9SAndroid Build Coastguard Worker        def isBinary(t):
2156*da0073e9SAndroid Build Coastguard Worker            return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0
2157*da0073e9SAndroid Build Coastguard Worker
2158*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(10, 10, dtype=dtype, device=device)
2159*da0073e9SAndroid Build Coastguard Worker
2160*da0073e9SAndroid Build Coastguard Worker        t.fill_(2)
2161*da0073e9SAndroid Build Coastguard Worker        t.bernoulli_(0.5)
2162*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isBinary(t))
2163*da0073e9SAndroid Build Coastguard Worker
2164*da0073e9SAndroid Build Coastguard Worker        for p_dtype in floating_types_and(*[torch.half] if device.startswith('cuda') else []):
2165*da0073e9SAndroid Build Coastguard Worker            p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10)
2166*da0073e9SAndroid Build Coastguard Worker            t.fill_(2)
2167*da0073e9SAndroid Build Coastguard Worker            t.bernoulli_(p)
2168*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isBinary(t))
2169*da0073e9SAndroid Build Coastguard Worker
2170*da0073e9SAndroid Build Coastguard Worker            t.fill_(2)
2171*da0073e9SAndroid Build Coastguard Worker            torch.bernoulli(torch.rand_like(t, dtype=p_dtype), out=t)
2172*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isBinary(t))
2173*da0073e9SAndroid Build Coastguard Worker
2174*da0073e9SAndroid Build Coastguard Worker            t.fill_(2)
2175*da0073e9SAndroid Build Coastguard Worker            t.bernoulli_(torch.rand_like(t, dtype=p_dtype))
2176*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isBinary(t))
2177*da0073e9SAndroid Build Coastguard Worker
2178*da0073e9SAndroid Build Coastguard Worker    @slowTest
2179*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half))
2180*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(torch.half))
2181*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli_edge_cases(self, device, dtype):
2182*da0073e9SAndroid Build Coastguard Worker        # Need to draw a lot of samples to cover every random floating point number.
2183*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(10000, 10000, dtype=dtype, device=device)  # probability of drawing "1" is 0
2184*da0073e9SAndroid Build Coastguard Worker        num_ones = (torch.bernoulli(a) == 1).sum()
2185*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_ones, 0)
2186*da0073e9SAndroid Build Coastguard Worker
2187*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(10000, 10000, dtype=dtype, device=device)  # probability of drawing "1" is 1
2188*da0073e9SAndroid Build Coastguard Worker        num_zeros = (torch.bernoulli(b) == 0).sum()
2189*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_zeros, 0)
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2192*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2193*da0073e9SAndroid Build Coastguard Worker    def test_exponential(self, device, dtype):
2194*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([10], dtype=dtype, device=device).exponential_(0.5)
2195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.dtype, dtype)
2196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(), torch.Size([1]))
2197*da0073e9SAndroid Build Coastguard Worker
2198*da0073e9SAndroid Build Coastguard Worker        # Tests extremal behavior
2199*da0073e9SAndroid Build Coastguard Worker        t = torch.empty((1,), device=device, dtype=dtype).exponential_(float('inf'))
2200*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t.item() == 0)
2201*da0073e9SAndroid Build Coastguard Worker
2202*da0073e9SAndroid Build Coastguard Worker        # Tests that negative lambda fails
2203*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
2204*da0073e9SAndroid Build Coastguard Worker            torch.empty((1,), device=device, dtype=dtype).exponential_(-0.5)
2205*da0073e9SAndroid Build Coastguard Worker
2206*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
2207*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float)
2208*da0073e9SAndroid Build Coastguard Worker    def test_exponential_no_zero(self, device, dtype):
2209*da0073e9SAndroid Build Coastguard Worker        # naively, an exact 0 can be generated by exponential_ with probability 2^-24,
2210*da0073e9SAndroid Build Coastguard Worker        # so we need to draw many samples, rather than just one, to check that
2211*da0073e9SAndroid Build Coastguard Worker        # it is never generated
2212*da0073e9SAndroid Build Coastguard Worker        # don't test CPU, that would make the test too slow
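        # Back-of-the-envelope: with 5e7 samples, a generator that emits 0 with probability
        # 2^-24 would be expected to produce about 5e7 * 2^-24 ~= 3 zeros, so at least one
        # zero would show up with probability roughly 1 - exp(-3) ~= 95%.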
2213*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(50000000, device=device, dtype=dtype).exponential_()
2214*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.min() > 0)
2215*da0073e9SAndroid Build Coastguard Worker
2216*da0073e9SAndroid Build Coastguard Worker    def _generate_correlation_tensors(self, device, dtype):
2217*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((0, 0), dtype=dtype, device=device)
2218*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((1, 0), dtype=dtype, device=device)
2219*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((0, 1), dtype=dtype, device=device)
2220*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((2,), dtype=dtype, device=device)
2221*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((2, 1), dtype=dtype, device=device)
2222*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((2, 2), dtype=dtype, device=device)
2223*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((2, 3), dtype=dtype, device=device)
2224*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((5, 10), dtype=dtype, device=device)
2225*da0073e9SAndroid Build Coastguard Worker        yield make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
2226*da0073e9SAndroid Build Coastguard Worker        if dtype != torch.int:
2227*da0073e9SAndroid Build Coastguard Worker            yield torch.tensor([0, -2, nan, 10.2, inf], dtype=dtype, device=device)
2228*da0073e9SAndroid Build Coastguard Worker
2229*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2230*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int, torch.float, torch.cfloat)
2231*da0073e9SAndroid Build Coastguard Worker    def test_corrcoef(self, device, dtype):
2232*da0073e9SAndroid Build Coastguard Worker        for x in self._generate_correlation_tensors(device, dtype):
2233*da0073e9SAndroid Build Coastguard Worker            res = torch.corrcoef(x)
2234*da0073e9SAndroid Build Coastguard Worker            ref = np.corrcoef(x.cpu().numpy())
2235*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, ref, exact_dtype=False)
2236*da0073e9SAndroid Build Coastguard Worker
2237*da0073e9SAndroid Build Coastguard Worker    @skipRocmIfTorchInductor
2238*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int, torch.float, torch.cfloat)
2239*da0073e9SAndroid Build Coastguard Worker    def test_cov(self, device, dtype):
2240*da0073e9SAndroid Build Coastguard Worker        def check(t, correction=1, fweights=None, aweights=None):
2241*da0073e9SAndroid Build Coastguard Worker            res = torch.cov(t, correction=correction, fweights=fweights, aweights=aweights)
2242*da0073e9SAndroid Build Coastguard Worker            t = t.cpu().numpy()
2243*da0073e9SAndroid Build Coastguard Worker            fweights = fweights.cpu().numpy() if fweights is not None else None
2244*da0073e9SAndroid Build Coastguard Worker            aweights = aweights.cpu().numpy() if aweights is not None else None
2245*da0073e9SAndroid Build Coastguard Worker            ref = np.cov(t, ddof=correction, fweights=fweights, aweights=aweights)
2246*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, ref, atol=1e-05, rtol=1e-05, exact_dtype=False)
2247*da0073e9SAndroid Build Coastguard Worker
2248*da0073e9SAndroid Build Coastguard Worker        for x in self._generate_correlation_tensors(device, dtype):
2249*da0073e9SAndroid Build Coastguard Worker            check(x)
2250*da0073e9SAndroid Build Coastguard Worker            num_observations = x.numel() if x.ndim < 2 else x.size(1)
2251*da0073e9SAndroid Build Coastguard Worker            if num_observations > 0:
2252*da0073e9SAndroid Build Coastguard Worker                fweights = torch.randint(1, 10, (num_observations,), device=device)
2253*da0073e9SAndroid Build Coastguard Worker                aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=1)
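                # fweights are integer frequency (repeat) counts and aweights are real-valued
                # reliability weights, mirroring the corresponding np.cov arguments.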
2254*da0073e9SAndroid Build Coastguard Worker                for correction, fw, aw in product([0, 1, 2], [None, fweights], [None, aweights]):
2255*da0073e9SAndroid Build Coastguard Worker                    check(x, correction, fw, aw)
2256*da0073e9SAndroid Build Coastguard Worker
2257*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
2258*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2259*da0073e9SAndroid Build Coastguard Worker    def test_uniform_kstest(self, device, dtype):
2260*da0073e9SAndroid Build Coastguard Worker        from scipy import stats
2261*da0073e9SAndroid Build Coastguard Worker        size = 1000
2262*da0073e9SAndroid Build Coastguard Worker        for from_ in [-42, 0, 4.2]:
2263*da0073e9SAndroid Build Coastguard Worker            for to_ in [-4.2, 0, 42]:
2264*da0073e9SAndroid Build Coastguard Worker                if to_ > from_:
2265*da0073e9SAndroid Build Coastguard Worker                    t = torch.empty(size, dtype=dtype, device=device).uniform_(from_, to_)
2266*da0073e9SAndroid Build Coastguard Worker                    res = stats.kstest(t.cpu().to(torch.double), 'uniform', args=(from_, (to_ - from_)))
2267*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(res.statistic < 0.1)
2268*da0073e9SAndroid Build Coastguard Worker
2269*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
2270*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half))
2271*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
2272*da0073e9SAndroid Build Coastguard Worker    def test_normal_kstest(self, device, dtype):
2273*da0073e9SAndroid Build Coastguard Worker        from scipy import stats
2274*da0073e9SAndroid Build Coastguard Worker        size = 1000
2275*da0073e9SAndroid Build Coastguard Worker        for mean in [-10, 0, 50]:
2276*da0073e9SAndroid Build Coastguard Worker            for std in [1, 5, 10]:
2277*da0073e9SAndroid Build Coastguard Worker                t = torch.empty(size, dtype=dtype, device=device).normal_(mean=mean, std=std)
2278*da0073e9SAndroid Build Coastguard Worker                res = stats.kstest(t.cpu().to(torch.double), 'norm', args=(mean, std))
2279*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(res.statistic < 0.1)
2280*da0073e9SAndroid Build Coastguard Worker
2281*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2282*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
2283*da0073e9SAndroid Build Coastguard Worker    @skipRocmIfTorchInductor
2284*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2285*da0073e9SAndroid Build Coastguard Worker    def test_lognormal_kstest(self, device, dtype):
2286*da0073e9SAndroid Build Coastguard Worker        from scipy import stats
2287*da0073e9SAndroid Build Coastguard Worker        size = 1000
2288*da0073e9SAndroid Build Coastguard Worker        for mean in [-3, 0, 7]:
2289*da0073e9SAndroid Build Coastguard Worker            for std in [1, 5, 7]:
2290*da0073e9SAndroid Build Coastguard Worker                t = torch.empty(size, dtype=dtype, device=device).log_normal_(mean=mean, std=std)
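                # scipy's lognorm takes (s, loc, scale) with s = std of the underlying normal,
                # loc = 0 and scale = exp(mean), hence the args tuple below.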
2291*da0073e9SAndroid Build Coastguard Worker                res = stats.kstest(t.cpu().to(torch.double), 'lognorm', args=(std, 0, math.exp(mean)))
2292*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half:
2293*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(res.statistic < 0.3)
2294*da0073e9SAndroid Build Coastguard Worker                else:
2295*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(res.statistic < 0.1)
2296*da0073e9SAndroid Build Coastguard Worker
2297*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2298*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
2299*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2300*da0073e9SAndroid Build Coastguard Worker    def test_exponential_kstest(self, device, dtype):
2301*da0073e9SAndroid Build Coastguard Worker        from scipy import stats
2302*da0073e9SAndroid Build Coastguard Worker        size = 1000
2303*da0073e9SAndroid Build Coastguard Worker        for lambd in [0.5, 1.0, 5.0]:
2304*da0073e9SAndroid Build Coastguard Worker            t = torch.empty(size, dtype=dtype, device=device).exponential_(lambd=lambd)
2305*da0073e9SAndroid Build Coastguard Worker            res = stats.kstest(t.cpu().to(torch.double), 'expon', args=(0, 1 / lambd,))
2306*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(res.statistic < 0.1)
2307*da0073e9SAndroid Build Coastguard Worker
2308*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2309*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
2310*da0073e9SAndroid Build Coastguard Worker    @skipRocmIfTorchInductor
2311*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2312*da0073e9SAndroid Build Coastguard Worker    def test_cauchy_kstest(self, device, dtype):
2313*da0073e9SAndroid Build Coastguard Worker        from scipy import stats
2314*da0073e9SAndroid Build Coastguard Worker        size = 1000
2315*da0073e9SAndroid Build Coastguard Worker        for median in [-10, 0, 50]:
2316*da0073e9SAndroid Build Coastguard Worker            for sigma in [0.5, 1.0, 10.0]:
2317*da0073e9SAndroid Build Coastguard Worker                t = torch.empty(size, dtype=dtype, device=device).cauchy_(median=median, sigma=sigma)
2318*da0073e9SAndroid Build Coastguard Worker                res = stats.kstest(t.cpu().to(torch.double), 'cauchy', args=(median, sigma))
2319*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(res.statistic < 0.1)
2320*da0073e9SAndroid Build Coastguard Worker
2321*da0073e9SAndroid Build Coastguard Worker    @slowTest
2322*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
2323*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float32)
2324*da0073e9SAndroid Build Coastguard Worker    def test_cauchy_no_inf(self, device, dtype):
2325*da0073e9SAndroid Build Coastguard Worker        # torch.float16 will have `inf` because of its smaller range.
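        # The loop below draws (2**16) * 2 batches of 2**16 samples each, i.e. roughly 2**33
        # values in total, and asserts that none of them is infinite.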
2326*da0073e9SAndroid Build Coastguard Worker        for _ in range((2**16) * 2):
2327*da0073e9SAndroid Build Coastguard Worker            x = torch.empty((2**16), dtype=dtype, device=device)
2328*da0073e9SAndroid Build Coastguard Worker            x.cauchy_()
2329*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.isinf().sum())
2330*da0073e9SAndroid Build Coastguard Worker
2331*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
2332*da0073e9SAndroid Build Coastguard Worker    def test_cauchy(self, device, dtype):
2333*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([10], dtype=dtype, device=device).cauchy_(0.0, 0.5)
2334*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.dtype, dtype)
2335*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.size(), torch.Size([1]))
2336*da0073e9SAndroid Build Coastguard Worker
2337*da0073e9SAndroid Build Coastguard Worker        # Tests extremal behavior
2338*da0073e9SAndroid Build Coastguard Worker        t = torch.empty((1,), device=device, dtype=dtype).cauchy_(float('inf'), 0.5)
2339*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t.item() == float('inf'))
2340*da0073e9SAndroid Build Coastguard Worker
2341*da0073e9SAndroid Build Coastguard Worker        # Tests that a non-positive sigma fails
2342*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
2343*da0073e9SAndroid Build Coastguard Worker            torch.empty((1,), device=device, dtype=dtype).cauchy_(0.0, 0.0)
2344*da0073e9SAndroid Build Coastguard Worker
2345*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2346*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
2347*da0073e9SAndroid Build Coastguard Worker    @skipRocmIfTorchInductor
2348*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half, torch.bfloat16))
2349*da0073e9SAndroid Build Coastguard Worker    def test_geometric_kstest(self, device, dtype):
2350*da0073e9SAndroid Build Coastguard Worker        from scipy import stats
2351*da0073e9SAndroid Build Coastguard Worker        size = 1000
2352*da0073e9SAndroid Build Coastguard Worker        for p in [0.2, 0.5, 0.8]:
2353*da0073e9SAndroid Build Coastguard Worker            t = torch.empty(size, dtype=dtype, device=device).geometric_(p=p)
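            # The geometric distribution is discrete, so instead of a KS test the samples are
            # binned over the support 1..98 and compared to size * pmf with a chi-square test.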
2354*da0073e9SAndroid Build Coastguard Worker            actual = np.histogram(t.cpu().to(torch.double), np.arange(1, 100))[0]
2355*da0073e9SAndroid Build Coastguard Worker            expected = stats.geom(p).pmf(np.arange(1, 99)) * size
2356*da0073e9SAndroid Build Coastguard Worker            res = stats.chisquare(actual, expected)
2357*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.pvalue, 1.0, atol=0.1, rtol=0)
2358*da0073e9SAndroid Build Coastguard Worker
2359*da0073e9SAndroid Build Coastguard Worker    # FIXME: find test suite for pdist and cdist
2360*da0073e9SAndroid Build Coastguard Worker    def test_pairwise_distance_empty(self, device):
2361*da0073e9SAndroid Build Coastguard Worker        shape = (2, 0)
2362*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(shape, device=device)
2363*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(shape, device=device)
2364*da0073e9SAndroid Build Coastguard Worker
2365*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y))
2366*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
2367*da0073e9SAndroid Build Coastguard Worker
2368*da0073e9SAndroid Build Coastguard Worker        shape = (0, 2)
2369*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(shape, device=device)
2370*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(shape, device=device)
2371*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y))
2372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
2373*da0073e9SAndroid Build Coastguard Worker
2374*da0073e9SAndroid Build Coastguard Worker    def test_pdist_empty(self, device):
2375*da0073e9SAndroid Build Coastguard Worker        shape = (0, 2)
2376*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(shape, device=device)
2377*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
2378*da0073e9SAndroid Build Coastguard Worker
2379*da0073e9SAndroid Build Coastguard Worker        shape = (1, 2)
2380*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(shape, device=device)
2381*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
2382*da0073e9SAndroid Build Coastguard Worker
2383*da0073e9SAndroid Build Coastguard Worker        shape = (3, 0)
2384*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(shape, device=device)
2385*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))
2386*da0073e9SAndroid Build Coastguard Worker
2387*da0073e9SAndroid Build Coastguard Worker    def test_cdist_empty(self, device):
2388*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((0, 5), device=device)
2389*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((4, 5), device=device)
2390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
2391*da0073e9SAndroid Build Coastguard Worker
2392*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((2, 5), device=device)
2393*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((0, 5), device=device)
2394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
2395*da0073e9SAndroid Build Coastguard Worker
2396*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((2, 0), device=device)
2397*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((3, 0), device=device)
2398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))
2399*da0073e9SAndroid Build Coastguard Worker
2400*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((2, 0), device=device)
2401*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((0, 0), device=device)
2402*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
2403*da0073e9SAndroid Build Coastguard Worker
2404*da0073e9SAndroid Build Coastguard Worker    def _brute_cdist(self, x, y, p=2):
2405*da0073e9SAndroid Build Coastguard Worker        r1 = x.shape[-2]
2406*da0073e9SAndroid Build Coastguard Worker        r2 = y.shape[-2]
2407*da0073e9SAndroid Build Coastguard Worker        if r1 == 0 or r2 == 0:
2408*da0073e9SAndroid Build Coastguard Worker            return torch.empty(r1, r2, device=x.device)
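        # x[..., None, :] has shape (..., r1, 1, m) and y[..., None, :, :] has shape
        # (..., 1, r2, m); the difference broadcasts to (..., r1, r2, m) and the norm over
        # the last dimension gives the full pairwise distance matrix.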
2409*da0073e9SAndroid Build Coastguard Worker        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
2410*da0073e9SAndroid Build Coastguard Worker
2411*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2412*da0073e9SAndroid Build Coastguard Worker    def test_cdist_norm(self, device):
2413*da0073e9SAndroid Build Coastguard Worker        for r1 in [3, 4, 5, 6]:
2414*da0073e9SAndroid Build Coastguard Worker            for m in [2, 3, 4, 10]:
2415*da0073e9SAndroid Build Coastguard Worker                for r2 in [4, 6, 7, 8]:
2416*da0073e9SAndroid Build Coastguard Worker                    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2417*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(r1, m, device=device)
2418*da0073e9SAndroid Build Coastguard Worker                        y = torch.randn(r2, m, device=device)
2419*da0073e9SAndroid Build Coastguard Worker                        if p == 2:
2420*da0073e9SAndroid Build Coastguard Worker                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2421*da0073e9SAndroid Build Coastguard Worker                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
2422*da0073e9SAndroid Build Coastguard Worker                                expected = self._brute_cdist(x, y, p=2)
2423*da0073e9SAndroid Build Coastguard Worker                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
2424*da0073e9SAndroid Build Coastguard Worker                        else:
2425*da0073e9SAndroid Build Coastguard Worker                            actual = torch.cdist(x, y, p=p)
2426*da0073e9SAndroid Build Coastguard Worker                            expected = self._brute_cdist(x, y, p=p)
2427*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(expected, actual)
2428*da0073e9SAndroid Build Coastguard Worker
2429*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2430*da0073e9SAndroid Build Coastguard Worker    def test_cdist_norm_batch(self, device):
2431*da0073e9SAndroid Build Coastguard Worker        for r1 in [3, 4, 5, 6]:
2432*da0073e9SAndroid Build Coastguard Worker            for m in [2, 3, 4, 10]:
2433*da0073e9SAndroid Build Coastguard Worker                for r2 in [4, 6, 7, 8]:
2434*da0073e9SAndroid Build Coastguard Worker                    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2435*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(2, 3, 6, r1, m, device=device)
2436*da0073e9SAndroid Build Coastguard Worker                        y = torch.randn(2, 3, 6, r2, m, device=device)
2437*da0073e9SAndroid Build Coastguard Worker                        if p == 2:
2438*da0073e9SAndroid Build Coastguard Worker                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2439*da0073e9SAndroid Build Coastguard Worker                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
2440*da0073e9SAndroid Build Coastguard Worker                                expected = self._brute_cdist(x, y, p=2)
2441*da0073e9SAndroid Build Coastguard Worker                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
2442*da0073e9SAndroid Build Coastguard Worker                        else:
2443*da0073e9SAndroid Build Coastguard Worker                            actual = torch.cdist(x, y, p=p)
2444*da0073e9SAndroid Build Coastguard Worker                            expected = self._brute_cdist(x, y, p=p)
2445*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(expected, actual)
2446*da0073e9SAndroid Build Coastguard Worker
2447*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
2448*da0073e9SAndroid Build Coastguard Worker    def test_cdist_cuda_backward(self, device):
2449*da0073e9SAndroid Build Coastguard Worker        for l1 in [1, 511, 513]:
2450*da0073e9SAndroid Build Coastguard Worker            for l2 in [1, 511, 513]:
2451*da0073e9SAndroid Build Coastguard Worker                for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2452*da0073e9SAndroid Build Coastguard Worker                    x1 = torch.randn(4, l1, 32, device=device, requires_grad=True)
2453*da0073e9SAndroid Build Coastguard Worker                    x2 = x1.clone().detach_().requires_grad_()
2454*da0073e9SAndroid Build Coastguard Worker                    y1 = torch.randn(4, l2, 32, device=device, requires_grad=True)
2455*da0073e9SAndroid Build Coastguard Worker                    y2 = y1.clone().detach_().requires_grad_()
2456*da0073e9SAndroid Build Coastguard Worker                    if p == 2:
2457*da0073e9SAndroid Build Coastguard Worker                        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2458*da0073e9SAndroid Build Coastguard Worker                            z1 = torch.cdist(x1, y1, p=2, compute_mode=cm).mean()
2459*da0073e9SAndroid Build Coastguard Worker                            z2 = self._brute_cdist(x2, y2, p=2).mean()
2460*da0073e9SAndroid Build Coastguard Worker                            z1.backward()
2461*da0073e9SAndroid Build Coastguard Worker                            z2.backward()
2462*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
2463*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
2464*da0073e9SAndroid Build Coastguard Worker                    else:
2465*da0073e9SAndroid Build Coastguard Worker                        z1 = torch.cdist(x1, y1, p=p).mean()
2466*da0073e9SAndroid Build Coastguard Worker                        z2 = self._brute_cdist(x2, y2, p=p).mean()
                        z1.backward()
                        z2.backward()
2467*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
2468*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
2469*da0073e9SAndroid Build Coastguard Worker
2470*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
2471*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.005)
2472*da0073e9SAndroid Build Coastguard Worker    def test_cdist_large(self, device):
2473*da0073e9SAndroid Build Coastguard Worker        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2474*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1000, 10, device=device)
2475*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(1000, 10, device=device)
2476*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2477*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
2478*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2479*da0073e9SAndroid Build Coastguard Worker
2480*da0073e9SAndroid Build Coastguard Worker    @slowTest
2481*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.01)
2482*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.01)
2483*da0073e9SAndroid Build Coastguard Worker    def test_cdist_large_batch(self, device):
2484*da0073e9SAndroid Build Coastguard Worker        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2485*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(4, 3, 1000, 10, device=device)
2486*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(4, 3, 1000, 10, device=device)
2487*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2488*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
2489*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2490*da0073e9SAndroid Build Coastguard Worker
2491*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
2492*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.005)
2493*da0073e9SAndroid Build Coastguard Worker    def test_cdist_non_contiguous(self, device):
2494*da0073e9SAndroid Build Coastguard Worker        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2495*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5, 7, device=device).mT
2496*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(5, 3, device=device).mT
2497*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2498*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
2499*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_contiguous())
2500*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
2501*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2502*da0073e9SAndroid Build Coastguard Worker
2503*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(7, 5, device=device)
2504*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(5, 3, device=device).t()
2505*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2506*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
2507*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.is_contiguous())
2508*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
2509*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2510*da0073e9SAndroid Build Coastguard Worker
2511*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5, 7, device=device).t()
2512*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(3, 5, device=device)
2513*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2514*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
2515*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_contiguous())
2516*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.is_contiguous())
2517*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2518*da0073e9SAndroid Build Coastguard Worker
2519*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
2520*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.005)
2521*da0073e9SAndroid Build Coastguard Worker    def test_cdist_non_contiguous_batch(self, device):
2522*da0073e9SAndroid Build Coastguard Worker        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2523*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
2524*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
2525*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2526*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
2527*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_contiguous())
2528*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
2529*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2530*da0073e9SAndroid Build Coastguard Worker
2531*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(7, 2, 7, 5, device=device)
2532*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(7, 2, 5, 3, device=device).mT
2533*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2534*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
2535*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.is_contiguous())
2536*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
2537*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2538*da0073e9SAndroid Build Coastguard Worker
2539*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(4, 5, 7, device=device).mT
2540*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(4, 3, 5, device=device)
2541*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
2542*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
2543*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_contiguous())
2544*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.is_contiguous())
2545*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2546*da0073e9SAndroid Build Coastguard Worker
2547*da0073e9SAndroid Build Coastguard Worker    # Maybe merge into OpInfo?
2548*da0073e9SAndroid Build Coastguard Worker    def test_cdist_euclidean_large(self, device):
2549*da0073e9SAndroid Build Coastguard Worker        def _test_euclidean_large_cdist(sizex, sizey=None):
2550*da0073e9SAndroid Build Coastguard Worker            if sizey is None:
2551*da0073e9SAndroid Build Coastguard Worker                sizey = sizex
2552*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(sizex, device=device, dtype=torch.float)
2553*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(sizey, device=device, dtype=torch.float)
2554*da0073e9SAndroid Build Coastguard Worker            eps = 1e-6
2555*da0073e9SAndroid Build Coastguard Worker            # nudge x away from y wherever they are within eps, to keep distances off extremal (non-differentiable) points
2556*da0073e9SAndroid Build Coastguard Worker            x = x - (((x - y) < eps).float() * 2 * eps)
2557*da0073e9SAndroid Build Coastguard Worker            x.requires_grad = True
2558*da0073e9SAndroid Build Coastguard Worker            y.requires_grad = True
2559*da0073e9SAndroid Build Coastguard Worker            dist = torch.cdist(x, y, p=2)
2560*da0073e9SAndroid Build Coastguard Worker            # Do a backward pass to check that it is valid for large
2561*da0073e9SAndroid Build Coastguard Worker            # matrices
2562*da0073e9SAndroid Build Coastguard Worker            loss = dist.sum()
2563*da0073e9SAndroid Build Coastguard Worker            loss.backward()
2564*da0073e9SAndroid Build Coastguard Worker
2565*da0073e9SAndroid Build Coastguard Worker        _test_euclidean_large_cdist((2000, 5))
2566*da0073e9SAndroid Build Coastguard Worker
2567*da0073e9SAndroid Build Coastguard Worker    # Ensure that cdist backward with p<1 does not produce NaNs
2568*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2569*da0073e9SAndroid Build Coastguard Worker    def test_cdist_grad_p_lt_1_no_nan(self, device):
2570*da0073e9SAndroid Build Coastguard Worker        for p in [0.99, 0.7, 0.5, 0.1, 0.01]:
2571*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1, 2, device=device)
2572*da0073e9SAndroid Build Coastguard Worker            y = x.clone().detach() + torch.tensor([[1., 0.]], device=device)
2573*da0073e9SAndroid Build Coastguard Worker            x.requires_grad = True
2574*da0073e9SAndroid Build Coastguard Worker            y.requires_grad = True
2575*da0073e9SAndroid Build Coastguard Worker            result = torch.cdist(x, y, p=p)
2576*da0073e9SAndroid Build Coastguard Worker            result.backward(torch.ones_like(result))
2577*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.isnan(x.grad).any())
2578*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.isnan(y.grad).any())
2579*da0073e9SAndroid Build Coastguard Worker
2580*da0073e9SAndroid Build Coastguard Worker    def test_cdist_same_inputs(self, device):
2581*da0073e9SAndroid Build Coastguard Worker        # Test to detect issues in the cdist gradient calculation
2582*da0073e9SAndroid Build Coastguard Worker        # when the distances are 0 (identical inputs), where the gradient can produce nan/inf
2583*da0073e9SAndroid Build Coastguard Worker        sizex = (1, 27, 32)
2584*da0073e9SAndroid Build Coastguard Worker        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2585*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(sizex, device=device, dtype=torch.float)
2586*da0073e9SAndroid Build Coastguard Worker            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
2587*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
2588*da0073e9SAndroid Build Coastguard Worker            eps = 1e-6
2589*da0073e9SAndroid Build Coastguard Worker            x.requires_grad = True
2590*da0073e9SAndroid Build Coastguard Worker            d = torch.cdist(x, y)
2591*da0073e9SAndroid Build Coastguard Worker            d.backward(dist_grad)
2592*da0073e9SAndroid Build Coastguard Worker            # Check that the backward pass does not produce invalid
2593*da0073e9SAndroid Build Coastguard Worker            # values such as nan or inf
2594*da0073e9SAndroid Build Coastguard Worker            assert torch.isfinite(x.grad).all()
2595*da0073e9SAndroid Build Coastguard Worker
2596*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2597*da0073e9SAndroid Build Coastguard Worker    def test_cumsum(self, device):
2598*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 100, device=device)
2599*da0073e9SAndroid Build Coastguard Worker        res1 = torch.cumsum(x, 1)
2600*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor([]).to(device)
2601*da0073e9SAndroid Build Coastguard Worker        torch.cumsum(x, 1, out=res2)
2602*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
2603*da0073e9SAndroid Build Coastguard Worker        x.cumsum_(1)
2604*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, x)
2605*da0073e9SAndroid Build Coastguard Worker
2606*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([[True, False, True],
2607*da0073e9SAndroid Build Coastguard Worker                          [False, False, False],
2608*da0073e9SAndroid Build Coastguard Worker                          [True, True, True]], device=device)
2609*da0073e9SAndroid Build Coastguard Worker        b = a.byte()
2610*da0073e9SAndroid Build Coastguard Worker        aRes = torch.cumsum(a, 0)
2611*da0073e9SAndroid Build Coastguard Worker        bRes = torch.cumsum(b, 0)
2612*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(aRes, bRes)
2613*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(aRes, torch.tensor([[1, 0, 1],
2614*da0073e9SAndroid Build Coastguard Worker                                             [1, 0, 1],
2615*da0073e9SAndroid Build Coastguard Worker                                             [2, 1, 2]]))
2616*da0073e9SAndroid Build Coastguard Worker
2617*da0073e9SAndroid Build Coastguard Worker        aRes = torch.cumsum(a, 1)
2618*da0073e9SAndroid Build Coastguard Worker        bRes = torch.cumsum(b, 1)
2619*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(aRes, bRes)
2620*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(aRes, torch.tensor([[1, 1, 2],
2621*da0073e9SAndroid Build Coastguard Worker                                             [0, 0, 0],
2622*da0073e9SAndroid Build Coastguard Worker                                             [1, 2, 3]]))
2623*da0073e9SAndroid Build Coastguard Worker
2624*da0073e9SAndroid Build Coastguard Worker        # Check that cumulative sum over a zero length dimension doesn't crash on backprop.
2625*da0073e9SAndroid Build Coastguard Worker        # Also check that cumsum over other dimensions in a tensor with a zero-length
2626*da0073e9SAndroid Build Coastguard Worker        # dimension also works.
2627*da0073e9SAndroid Build Coastguard Worker        # Also include a basic suite of similar tests for other base cases.
2628*da0073e9SAndroid Build Coastguard Worker        shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2629*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
2630*da0073e9SAndroid Build Coastguard Worker            for dim in range(len(shape)):
2631*da0073e9SAndroid Build Coastguard Worker                raw_tensor = torch.zeros(*shape, requires_grad=True)
2632*da0073e9SAndroid Build Coastguard Worker                integrated = raw_tensor.cumsum(dim=dim)
2633*da0073e9SAndroid Build Coastguard Worker                # Check that backward does not crash
2634*da0073e9SAndroid Build Coastguard Worker                integrated.sum().backward()
2635*da0073e9SAndroid Build Coastguard Worker                # Check that output maintained correct shape
2636*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2637*da0073e9SAndroid Build Coastguard Worker
2638*da0073e9SAndroid Build Coastguard Worker        # Check a scalar example
2639*da0073e9SAndroid Build Coastguard Worker        raw_tensor = torch.tensor(3., requires_grad=True)
2640*da0073e9SAndroid Build Coastguard Worker        integrated = raw_tensor.cumsum(dim=-1)
2641*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(raw_tensor, integrated)
2642*da0073e9SAndroid Build Coastguard Worker        # Check that backward does not crash
2643*da0073e9SAndroid Build Coastguard Worker        integrated.sum().backward()
2644*da0073e9SAndroid Build Coastguard Worker        # Check that output maintained correct shape
2645*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2646*da0073e9SAndroid Build Coastguard Worker
2647*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2648*da0073e9SAndroid Build Coastguard Worker    def test_cumprod(self, device):
2649*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 100, device=device)
2650*da0073e9SAndroid Build Coastguard Worker        res1 = torch.cumprod(x, 1)
2651*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor([]).to(device)
2652*da0073e9SAndroid Build Coastguard Worker        if not TEST_WITH_TORCHINDUCTOR:
2653*da0073e9SAndroid Build Coastguard Worker            torch.cumprod(x, 1, out=res2)
2654*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res2)
2655*da0073e9SAndroid Build Coastguard Worker        x.cumprod_(1)
2656*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, x)
2657*da0073e9SAndroid Build Coastguard Worker
2658*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([[True, False, True],
2659*da0073e9SAndroid Build Coastguard Worker                          [False, False, False],
2660*da0073e9SAndroid Build Coastguard Worker                          [True, True, True]], dtype=torch.bool, device=device)
2661*da0073e9SAndroid Build Coastguard Worker        b = a.byte()
2662*da0073e9SAndroid Build Coastguard Worker        aRes = torch.cumprod(a, 0)
2663*da0073e9SAndroid Build Coastguard Worker        bRes = torch.cumprod(b, 0)
2664*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(aRes, bRes)
2665*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(aRes, torch.tensor([[1, 0, 1],
2666*da0073e9SAndroid Build Coastguard Worker                                             [0, 0, 0],
2667*da0073e9SAndroid Build Coastguard Worker                                             [0, 0, 0]]))
2668*da0073e9SAndroid Build Coastguard Worker
2669*da0073e9SAndroid Build Coastguard Worker        aRes = torch.cumprod(a, 1)
2670*da0073e9SAndroid Build Coastguard Worker        bRes = torch.cumprod(b, 1)
2671*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(aRes, bRes)
2672*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(aRes, torch.tensor([[1, 0, 0],
2673*da0073e9SAndroid Build Coastguard Worker                                             [0, 0, 0],
2674*da0073e9SAndroid Build Coastguard Worker                                             [1, 1, 1]]))
2675*da0073e9SAndroid Build Coastguard Worker
2676*da0073e9SAndroid Build Coastguard Worker        # Check that cumulative prod over a zero length dimension doesn't crash on backprop.
2677*da0073e9SAndroid Build Coastguard Worker        # Also check that cumprod over other dimensions in a tensor with a zero-length
2678*da0073e9SAndroid Build Coastguard Worker        # dimension also works.
2679*da0073e9SAndroid Build Coastguard Worker        # Also include a basic suite of similar tests for other base cases.
2680*da0073e9SAndroid Build Coastguard Worker        shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2681*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
2682*da0073e9SAndroid Build Coastguard Worker            for dim in range(len(shape)):
2683*da0073e9SAndroid Build Coastguard Worker                raw_tensor = torch.zeros(*shape, requires_grad=True)
2684*da0073e9SAndroid Build Coastguard Worker                integrated = raw_tensor.cumprod(dim=dim)
2685*da0073e9SAndroid Build Coastguard Worker                # Check that backward does not crash
2686*da0073e9SAndroid Build Coastguard Worker                integrated.sum().backward()
2687*da0073e9SAndroid Build Coastguard Worker                # Check that output maintained correct shape
2688*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2689*da0073e9SAndroid Build Coastguard Worker
2690*da0073e9SAndroid Build Coastguard Worker        # Check a scalar example
2691*da0073e9SAndroid Build Coastguard Worker        raw_tensor = torch.tensor(3., requires_grad=True)
2692*da0073e9SAndroid Build Coastguard Worker        integrated = raw_tensor.cumprod(dim=-1)
2693*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(raw_tensor, integrated)
2694*da0073e9SAndroid Build Coastguard Worker        # Check that backward does not crash
2695*da0073e9SAndroid Build Coastguard Worker        integrated.sum().backward()
2696*da0073e9SAndroid Build Coastguard Worker        # Check that output maintained correct shape
2697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2698*da0073e9SAndroid Build Coastguard Worker
2699*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2700*da0073e9SAndroid Build Coastguard Worker    def test_cummax_cummin(self, device):
2701*da0073e9SAndroid Build Coastguard Worker        def test_ops(op, string_of_function_name, expected_output1, expected_output2):
2702*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(100, 100, device=device)
2703*da0073e9SAndroid Build Coastguard Worker            out1 = op(x, 1)
2704*da0073e9SAndroid Build Coastguard Worker            res2 = torch.empty(0, device=device)
2705*da0073e9SAndroid Build Coastguard Worker            indices2 = torch.empty(0, dtype=torch.int64, device=device)
2706*da0073e9SAndroid Build Coastguard Worker            op(x, 1, out=(res2, indices2))
2707*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out1[0], res2)
2708*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out1[1], indices2)
2709*da0073e9SAndroid Build Coastguard Worker
2710*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([[True, False, True],
2711*da0073e9SAndroid Build Coastguard Worker                              [False, False, False],
2712*da0073e9SAndroid Build Coastguard Worker                              [True, True, True]], dtype=torch.bool, device=device)
2713*da0073e9SAndroid Build Coastguard Worker            b = a.byte()
2714*da0073e9SAndroid Build Coastguard Worker            aRes = op(a, 0)
2715*da0073e9SAndroid Build Coastguard Worker            bRes = op(b, 0)
2716*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(aRes[0], bRes[0].bool())
2717*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(aRes[0], expected_output1.bool())
2718*da0073e9SAndroid Build Coastguard Worker
2719*da0073e9SAndroid Build Coastguard Worker            # test inf and nan input
2720*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1])
2721*da0073e9SAndroid Build Coastguard Worker            xRes = op(x, 0)[0]
2722*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(xRes, expected_output2)
2723*da0073e9SAndroid Build Coastguard Worker
2724*da0073e9SAndroid Build Coastguard Worker            # op shouldn't accept values/indices tensors whose dtype, device type, or layout
2725*da0073e9SAndroid Build Coastguard Worker            # differs from that of the input tensor
2726*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(10)
2727*da0073e9SAndroid Build Coastguard Worker            values = torch.empty(0, dtype=torch.int16)
2728*da0073e9SAndroid Build Coastguard Worker            indices = torch.empty(0, dtype=torch.int64)
2729*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
2730*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
2731*da0073e9SAndroid Build Coastguard Worker                    'expected scalar_type Float but found Short'):
2732*da0073e9SAndroid Build Coastguard Worker                op(t, 0, out=(values, indices))
2733*da0073e9SAndroid Build Coastguard Worker
2734*da0073e9SAndroid Build Coastguard Worker            # Check that op over a zero length dimension doesn't crash on backprop.
2735*da0073e9SAndroid Build Coastguard Worker            # Also check that op over other dimensions in a tensor with a zero-length
2736*da0073e9SAndroid Build Coastguard Worker            # dimension also works.
2737*da0073e9SAndroid Build Coastguard Worker            # Also include a basic suite of similar tests for other base cases.
2738*da0073e9SAndroid Build Coastguard Worker            shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2739*da0073e9SAndroid Build Coastguard Worker            for shape in shapes:
2740*da0073e9SAndroid Build Coastguard Worker                for dim in range(len(shape)):
2741*da0073e9SAndroid Build Coastguard Worker                    raw_tensor = torch.zeros(*shape, requires_grad=True)
2742*da0073e9SAndroid Build Coastguard Worker                    integrated = getattr(raw_tensor, string_of_function_name)(dim=dim)
2743*da0073e9SAndroid Build Coastguard Worker                    # Check that backward does not crash
2744*da0073e9SAndroid Build Coastguard Worker                    integrated[0].sum().backward()
2745*da0073e9SAndroid Build Coastguard Worker                    # Check that output maintained correct shape
2746*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2747*da0073e9SAndroid Build Coastguard Worker
2748*da0073e9SAndroid Build Coastguard Worker            # Check a scalar example
2749*da0073e9SAndroid Build Coastguard Worker            raw_tensor = torch.tensor(3., requires_grad=True)
2750*da0073e9SAndroid Build Coastguard Worker            integrated = getattr(raw_tensor, string_of_function_name)(dim=-1)
2751*da0073e9SAndroid Build Coastguard Worker            # Check that backward does not crash
2752*da0073e9SAndroid Build Coastguard Worker            integrated[0].sum().backward()
2753*da0073e9SAndroid Build Coastguard Worker            # Check that output maintained correct shape
2754*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2755*da0073e9SAndroid Build Coastguard Worker
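        # A note on the expected values below: cummax/cummin propagate nan once it
        # appears, so entries after the first nan in the input are expected to stay nan.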
2756*da0073e9SAndroid Build Coastguard Worker        expected_out = torch.tensor([4, inf, inf, inf, inf, nan, nan])
2757*da0073e9SAndroid Build Coastguard Worker        test_ops(torch.cummax, "cummax", torch.tensor([[1, 0, 1],
2758*da0073e9SAndroid Build Coastguard Worker                                                       [1, 0, 1],
2759*da0073e9SAndroid Build Coastguard Worker                                                       [1, 1, 1]]), expected_out)
2760*da0073e9SAndroid Build Coastguard Worker
2761*da0073e9SAndroid Build Coastguard Worker        expected_out = torch.tensor([4, 4, 1.5, -inf, -inf, nan, nan])
2762*da0073e9SAndroid Build Coastguard Worker        test_ops(torch.cummin, "cummin", torch.tensor([[1, 0, 1],
2763*da0073e9SAndroid Build Coastguard Worker                                                       [0, 0, 0],
2764*da0073e9SAndroid Build Coastguard Worker                                                       [0, 0, 0]]), expected_out)
2765*da0073e9SAndroid Build Coastguard Worker
2766*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
2767*da0073e9SAndroid Build Coastguard Worker    def test_logcumsumexp(self, device):
2768*da0073e9SAndroid Build Coastguard Worker        def logcumsumexp(a, axis):
2769*da0073e9SAndroid Build Coastguard Worker            return torch.cumsum(a.exp(), axis=axis).log_()
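        # The naive reference above uses the identity
        #   logcumsumexp(x)[i] = log(sum_{j <= i} exp(x[j]))
        # torch.logcumsumexp is expected to match it on the well-scaled inputs used
        # here while remaining numerically stable for large magnitudes.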
2770*da0073e9SAndroid Build Coastguard Worker
2771*da0073e9SAndroid Build Coastguard Worker        axis = -1
2772*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(100, 100, device=device)
2773*da0073e9SAndroid Build Coastguard Worker
2774*da0073e9SAndroid Build Coastguard Worker        actual = a.logcumsumexp(axis)
2775*da0073e9SAndroid Build Coastguard Worker        expected = logcumsumexp(a, axis)
2776*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.dtype, actual.dtype)
2777*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected.shape, actual.shape)
2778*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
2779*da0073e9SAndroid Build Coastguard Worker
2780*da0073e9SAndroid Build Coastguard Worker        # check -inf and nan handling
2781*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([-float('inf'), -float('inf'), 1.0, 1.0, float('inf'),
2782*da0073e9SAndroid Build Coastguard Worker                         float('inf'), float('nan'), 1.0, 1.0], device=device)
2783*da0073e9SAndroid Build Coastguard Worker        x2d = x.unsqueeze(0).expand(2, -1)
2784*da0073e9SAndroid Build Coastguard Worker
2785*da0073e9SAndroid Build Coastguard Worker        for inp in (x, x2d):
2786*da0073e9SAndroid Build Coastguard Worker            actual = inp.logcumsumexp(axis)
2787*da0073e9SAndroid Build Coastguard Worker            expected = logcumsumexp(inp, axis)
2788*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
2789*da0073e9SAndroid Build Coastguard Worker
2790*da0073e9SAndroid Build Coastguard Worker        # Check that the out= tensor is actually written in place
2791*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 2, device=device)
2792*da0073e9SAndroid Build Coastguard Worker        inplace_out = torch.zeros(5, 2, device=device)
2793*da0073e9SAndroid Build Coastguard Worker
2794*da0073e9SAndroid Build Coastguard Worker        expected = logcumsumexp(b, axis)
2795*da0073e9SAndroid Build Coastguard Worker        torch.logcumsumexp(b, axis=axis, out=inplace_out)
2796*da0073e9SAndroid Build Coastguard Worker
2797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inplace_out, expected)
2798*da0073e9SAndroid Build Coastguard Worker
2799*da0073e9SAndroid Build Coastguard Worker        # Check input and inplace_output type mismatch
2800*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 2, device=device, dtype=torch.float64)
2801*da0073e9SAndroid Build Coastguard Worker        inplace_out = torch.zeros(5, 2, device=device, dtype=torch.float32)
2802*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2803*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
2804*da0073e9SAndroid Build Coastguard Worker                'expected scalar_type Double but found Float'):
2805*da0073e9SAndroid Build Coastguard Worker            torch.logcumsumexp(b, axis, out=inplace_out)
2806*da0073e9SAndroid Build Coastguard Worker
2807*da0073e9SAndroid Build Coastguard Worker    def _test_diff_numpy(self, t, dims=None):
2808*da0073e9SAndroid Build Coastguard Worker        # Helper for test_diff to compare with NumPy reference implementation
2809*da0073e9SAndroid Build Coastguard Worker        def to_np(t):
2810*da0073e9SAndroid Build Coastguard Worker            if t.dtype == torch.bfloat16:
2811*da0073e9SAndroid Build Coastguard Worker                return t.to(dtype=torch.float, device="cpu").numpy()
2812*da0073e9SAndroid Build Coastguard Worker            else:
2813*da0073e9SAndroid Build Coastguard Worker                return t.cpu().numpy()
2814*da0073e9SAndroid Build Coastguard Worker
2815*da0073e9SAndroid Build Coastguard Worker        for dim in dims if dims else range(t.dim()):
2816*da0073e9SAndroid Build Coastguard Worker            prepend = t.narrow(dim, 0, 1)
2817*da0073e9SAndroid Build Coastguard Worker            append = t.narrow(dim, 0, 1)
2818*da0073e9SAndroid Build Coastguard Worker            np_t = to_np(t)
2819*da0073e9SAndroid Build Coastguard Worker
2820*da0073e9SAndroid Build Coastguard Worker            # test without prepend or append
2821*da0073e9SAndroid Build Coastguard Worker            for n in range(t.size(dim)):
2822*da0073e9SAndroid Build Coastguard Worker                actual = torch.diff(t, dim=dim, n=n)
2823*da0073e9SAndroid Build Coastguard Worker                expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n))
2824*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected.to(t.dtype))
2825*da0073e9SAndroid Build Coastguard Worker
2826*da0073e9SAndroid Build Coastguard Worker            # test when prepend and append's size along dim is 1
2827*da0073e9SAndroid Build Coastguard Worker            for n in range(1, t.size(dim) + 4):
2828*da0073e9SAndroid Build Coastguard Worker                actual = torch.diff(t, dim=dim, n=n, prepend=prepend, append=append)
2829*da0073e9SAndroid Build Coastguard Worker                expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n, prepend=to_np(prepend), append=to_np(append)))
2830*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected.to(t.dtype))
2831*da0073e9SAndroid Build Coastguard Worker
2832*da0073e9SAndroid Build Coastguard Worker            # test when prepend and append's size along dim != 1
2833*da0073e9SAndroid Build Coastguard Worker            for n in range(1, t.size(dim) * 3):
2834*da0073e9SAndroid Build Coastguard Worker                actual = torch.diff(t, dim=dim, n=n, prepend=t, append=t)
2835*da0073e9SAndroid Build Coastguard Worker                expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n, prepend=np_t, append=np_t))
2836*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected.to(t.dtype))
2837*da0073e9SAndroid Build Coastguard Worker
2838*da0073e9SAndroid Build Coastguard Worker    # All tensors appear contiguous on XLA
2839*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2840*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
2841*da0073e9SAndroid Build Coastguard Worker    def test_diff_noncontig(self, device, dtype):
2842*da0073e9SAndroid Build Coastguard Worker        shapes = (
2843*da0073e9SAndroid Build Coastguard Worker            (1,),
2844*da0073e9SAndroid Build Coastguard Worker            (1, 5),
2845*da0073e9SAndroid Build Coastguard Worker            (3, 5),
2846*da0073e9SAndroid Build Coastguard Worker            (1, 5, 1),
2847*da0073e9SAndroid Build Coastguard Worker            (2, 3, 5))
2848*da0073e9SAndroid Build Coastguard Worker
2849*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
2850*da0073e9SAndroid Build Coastguard Worker            contig = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
2851*da0073e9SAndroid Build Coastguard Worker
2852*da0073e9SAndroid Build Coastguard Worker            non_contig = torch.empty(shape + (2, 2), device=device, dtype=dtype)[..., 0]
2853*da0073e9SAndroid Build Coastguard Worker            non_contig = non_contig.select(-1, -1)
2854*da0073e9SAndroid Build Coastguard Worker            non_contig.copy_(contig)
2855*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(not non_contig.is_contiguous() or shape == (1,))
2856*da0073e9SAndroid Build Coastguard Worker
2857*da0073e9SAndroid Build Coastguard Worker            self._test_diff_numpy(non_contig)
2858*da0073e9SAndroid Build Coastguard Worker
2859*da0073e9SAndroid Build Coastguard Worker    # RngNormal not implemented for type f16 for XLA
2860*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bool))
2861*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool))
2862*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool))
2863*da0073e9SAndroid Build Coastguard Worker    def test_diff(self, device, dtype):
2864*da0073e9SAndroid Build Coastguard Worker        shapes = (
2865*da0073e9SAndroid Build Coastguard Worker            (1,),
2866*da0073e9SAndroid Build Coastguard Worker            (1, 5),
2867*da0073e9SAndroid Build Coastguard Worker            (3, 5),
2868*da0073e9SAndroid Build Coastguard Worker            (1, 5, 1),
2869*da0073e9SAndroid Build Coastguard Worker            (2, 3, 5))
2870*da0073e9SAndroid Build Coastguard Worker
2871*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
2872*da0073e9SAndroid Build Coastguard Worker            contig = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
2873*da0073e9SAndroid Build Coastguard Worker            self._test_diff_numpy(contig)
2874*da0073e9SAndroid Build Coastguard Worker
2875*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(2, 3)
2876*da0073e9SAndroid Build Coastguard Worker
2877*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2878*da0073e9SAndroid Build Coastguard Worker                RuntimeError, 'diff expects prepend or append to be the same dimension as input'):
2879*da0073e9SAndroid Build Coastguard Worker            invalid_prepend = torch.tensor([1, 2, 3], device=device, dtype=dtype)
2880*da0073e9SAndroid Build Coastguard Worker            t.diff(dim=0, prepend=invalid_prepend)
2881*da0073e9SAndroid Build Coastguard Worker
2882*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2883*da0073e9SAndroid Build Coastguard Worker                RuntimeError, 'diff expects the shape of tensor to prepend or append to match that of input'):
2884*da0073e9SAndroid Build Coastguard Worker            invalid_prepend = torch.tensor([[0, 1]], device=device, dtype=dtype)
2885*da0073e9SAndroid Build Coastguard Worker            t.diff(dim=0, prepend=invalid_prepend)
2886*da0073e9SAndroid Build Coastguard Worker
2887*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2888*da0073e9SAndroid Build Coastguard Worker                RuntimeError, 'diff expects input to be at least one-dimensional'):
2889*da0073e9SAndroid Build Coastguard Worker            scalar = torch.tensor(2, device=device, dtype=dtype)
2890*da0073e9SAndroid Build Coastguard Worker            torch.diff(scalar)
2891*da0073e9SAndroid Build Coastguard Worker
2892*da0073e9SAndroid Build Coastguard Worker    # If the given input arg is not a list, return a single-element list: [arg]
2893*da0073e9SAndroid Build Coastguard Worker    def _wrap_to_list(self, input_array):
2894*da0073e9SAndroid Build Coastguard Worker        return input_array if isinstance(input_array, list) else [input_array]
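    # For example (illustrative only): _wrap_to_list(3) -> [3], while
    # _wrap_to_list([1, 2]) is returned unchanged as [1, 2].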
2895*da0073e9SAndroid Build Coastguard Worker
2896*da0073e9SAndroid Build Coastguard Worker    # Ensure that inf, -inf, and nan values do not cause divergence between NumPy and PyTorch.
2897*da0073e9SAndroid Build Coastguard Worker    # There are two possible sources of divergence:
2898*da0073e9SAndroid Build Coastguard Worker    # 1. When a and b are both real numbers with very small absolute values (i.e. very close to 0.0),
2899*da0073e9SAndroid Build Coastguard Worker    # the result of a/b can be inf, -inf, or nan, which causes divergence.
2900*da0073e9SAndroid Build Coastguard Worker    # 2. When dividing complex numbers by zero. For example, for a = torch.tensor(3+5j),
2901*da0073e9SAndroid Build Coastguard Worker    # a/0 equals nan + nan*j in PyTorch but inf + inf*j in NumPy.
2902*da0073e9SAndroid Build Coastguard Worker    def _inf_nan_preprocess(self, actual, expected):
2903*da0073e9SAndroid Build Coastguard Worker        for i in range(len(expected)):
2904*da0073e9SAndroid Build Coastguard Worker            expected[i] = np.nan_to_num(expected[i], nan=nan, posinf=nan, neginf=nan)
2905*da0073e9SAndroid Build Coastguard Worker            # nan_to_num is not defined for complex tensors in PyTorch.
2906*da0073e9SAndroid Build Coastguard Worker            if actual[i].dtype == torch.complex64:
2907*da0073e9SAndroid Build Coastguard Worker                actual[i].real = torch.nan_to_num(actual[i].real, nan=nan, posinf=nan, neginf=nan)
2908*da0073e9SAndroid Build Coastguard Worker                actual[i].imag = torch.nan_to_num(actual[i].imag, nan=nan, posinf=nan, neginf=nan)
2909*da0073e9SAndroid Build Coastguard Worker            else:
2910*da0073e9SAndroid Build Coastguard Worker                actual[i] = torch.nan_to_num(actual[i], nan=nan, posinf=nan, neginf=nan)
2911*da0073e9SAndroid Build Coastguard Worker
2912*da0073e9SAndroid Build Coastguard Worker        return actual, expected
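    # Illustrative use of the helper above (a minimal sketch; values follow the
    # comment's example): for a = torch.tensor(3 + 5j), PyTorch gives a / 0 ==
    # nan + nan*j while NumPy gives inf + inf*j, so both sides are mapped to nan
    # before being compared with equal_nan=True.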
2913*da0073e9SAndroid Build Coastguard Worker
2914*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2915*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.long, torch.float32, torch.complex64)
2916*da0073e9SAndroid Build Coastguard Worker    def test_gradient_all(self, device, dtype):
2917*da0073e9SAndroid Build Coastguard Worker        def create_scalar(shape):
2918*da0073e9SAndroid Build Coastguard Worker            return make_tensor((1,), device='cpu', dtype=dtype, low=1.).item()
2919*da0073e9SAndroid Build Coastguard Worker
2920*da0073e9SAndroid Build Coastguard Worker        def create_list(shape):
2921*da0073e9SAndroid Build Coastguard Worker            return make_tensor((len(shape),), device='cpu', dtype=dtype, low=1.).tolist()
2922*da0073e9SAndroid Build Coastguard Worker
2923*da0073e9SAndroid Build Coastguard Worker        def create_coordinate_tensors(shape):
2924*da0073e9SAndroid Build Coastguard Worker            tensor_list = []
2925*da0073e9SAndroid Build Coastguard Worker            for i in range(len(shape)):
2926*da0073e9SAndroid Build Coastguard Worker                tensor_list.append(make_tensor((shape[i],), device=device, dtype=dtype))
2927*da0073e9SAndroid Build Coastguard Worker            return tensor_list
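        # The three helpers above mirror the spacing forms torch.gradient accepts:
        # a single scalar (uniform spacing for all queried dims), a list of scalars
        # (one per dim), and a list of 1-D coordinate tensors (one per dim).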
2928*da0073e9SAndroid Build Coastguard Worker
2929*da0073e9SAndroid Build Coastguard Worker        def filter_shape(shape, dim):
2930*da0073e9SAndroid Build Coastguard Worker            filtered_shape = []
2931*da0073e9SAndroid Build Coastguard Worker            for i in range(len(dim)):
2932*da0073e9SAndroid Build Coastguard Worker                filtered_shape.append(shape[dim[i]])
2933*da0073e9SAndroid Build Coastguard Worker            return filtered_shape
2934*da0073e9SAndroid Build Coastguard Worker
2935*da0073e9SAndroid Build Coastguard Worker        # shape, dims format
2936*da0073e9SAndroid Build Coastguard Worker        test_cases = (
2937*da0073e9SAndroid Build Coastguard Worker            ((5,), (0,)),
2938*da0073e9SAndroid Build Coastguard Worker            ((4, 4), (0, 1)),
2939*da0073e9SAndroid Build Coastguard Worker            ((3, 3, 3), (-1, 0)),
2940*da0073e9SAndroid Build Coastguard Worker            ((4, 4, 4), (2,)),
2941*da0073e9SAndroid Build Coastguard Worker            ((4, 4, 4), (0, 1)),
2942*da0073e9SAndroid Build Coastguard Worker            ((4, 4, 4, 3), (0, 2, 3)),
2943*da0073e9SAndroid Build Coastguard Worker            ((4, 5, 3, 4, 3), (1, 2)),
2944*da0073e9SAndroid Build Coastguard Worker            ((4, 3, 6, 5, 3), (2, 4)),
2945*da0073e9SAndroid Build Coastguard Worker            ((4, 3, 3, 5, 3), (0, 1, 2, 3, 4)),
2946*da0073e9SAndroid Build Coastguard Worker            ((1, 3, 3), (1, 2)),
2947*da0073e9SAndroid Build Coastguard Worker            ((1, 5), (1,)),
2948*da0073e9SAndroid Build Coastguard Worker        )
2949*da0073e9SAndroid Build Coastguard Worker
2950*da0073e9SAndroid Build Coastguard Worker        for case, contig, edge_order, space_fn in product(test_cases, [True, False], [1, 2],
2951*da0073e9SAndroid Build Coastguard Worker                                                          (create_scalar, create_list, create_coordinate_tensors)):
2952*da0073e9SAndroid Build Coastguard Worker            shape, dims = case
2953*da0073e9SAndroid Build Coastguard Worker            # filter shape by dims before passing filtered shape to create_* functions
2954*da0073e9SAndroid Build Coastguard Worker            filtered_shape = filter_shape(shape, dims)
2955*da0073e9SAndroid Build Coastguard Worker
2956*da0073e9SAndroid Build Coastguard Worker            spacing = space_fn(filtered_shape)
2957*da0073e9SAndroid Build Coastguard Worker            t = make_tensor(shape, device=device, dtype=dtype, noncontiguous=not contig)
2958*da0073e9SAndroid Build Coastguard Worker            t_np = t.cpu().numpy()
2959*da0073e9SAndroid Build Coastguard Worker
2960*da0073e9SAndroid Build Coastguard Worker            actual = torch.gradient(t, spacing=spacing, dim=dims, edge_order=edge_order)
2961*da0073e9SAndroid Build Coastguard Worker            if space_fn == create_coordinate_tensors and spacing[0].device.type != 'cpu':
2962*da0073e9SAndroid Build Coastguard Worker                spacing = [space.cpu().detach().numpy() for space in spacing]
2963*da0073e9SAndroid Build Coastguard Worker            expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order)
2964*da0073e9SAndroid Build Coastguard Worker            actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected))
2965*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, equal_nan=True, atol=1e-4, rtol=0, exact_dtype=False)
2966*da0073e9SAndroid Build Coastguard Worker
2967*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2968*da0073e9SAndroid Build Coastguard Worker    @slowTestIf(TEST_WITH_TORCHINDUCTOR)
2969*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.long, torch.float32, torch.complex64)
2970*da0073e9SAndroid Build Coastguard Worker    def test_gradient_extreme_cases(self, device, dtype):
2971*da0073e9SAndroid Build Coastguard Worker        # Test behaviour for inf and nan values
2972*da0073e9SAndroid Build Coastguard Worker        actual = torch.gradient(torch.tensor([2, -2, inf, inf, -inf, -inf, inf, 3, -inf, 2, nan, nan, 3, inf, nan]))
2973*da0073e9SAndroid Build Coastguard Worker        expected = np.gradient(np.array([2, -2, inf, inf, -inf, -inf, inf, 3, -inf, 2, nan, nan, 3, inf, nan]))
2974*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, self._wrap_to_list(expected), exact_dtype=False)
2975*da0073e9SAndroid Build Coastguard Worker
2976*da0073e9SAndroid Build Coastguard Worker        # Test behaviour in very big tensors
2977*da0073e9SAndroid Build Coastguard Worker        large_size = 100000
2978*da0073e9SAndroid Build Coastguard Worker        t = make_tensor((large_size,), dtype=dtype, device=device)
2979*da0073e9SAndroid Build Coastguard Worker        t_np = t.cpu().numpy()
2980*da0073e9SAndroid Build Coastguard Worker        coordinates_np = np.random.randn(large_size)
2981*da0073e9SAndroid Build Coastguard Worker        coordinates = [torch.tensor(coordinates_np, device=device)]
2982*da0073e9SAndroid Build Coastguard Worker        actual = torch.gradient(t, spacing=coordinates, dim=0, edge_order=1)
2983*da0073e9SAndroid Build Coastguard Worker        expected = [np.gradient(t_np, coordinates_np, axis=0, edge_order=1)]
2984*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected, exact_dtype=False)
2985*da0073e9SAndroid Build Coastguard Worker
2986*da0073e9SAndroid Build Coastguard Worker        actual = torch.gradient(t, spacing=coordinates, dim=0, edge_order=2)
2987*da0073e9SAndroid Build Coastguard Worker        expected = [np.gradient(t_np, coordinates_np, axis=0, edge_order=2)]
2988*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected, exact_dtype=False)
2989*da0073e9SAndroid Build Coastguard Worker
2990*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2991*da0073e9SAndroid Build Coastguard Worker    def test_gradient_type_promotion(self, device):
2992*da0073e9SAndroid Build Coastguard Worker        inputs = (
2993*da0073e9SAndroid Build Coastguard Worker            make_tensor((4, 4), device=device, dtype=torch.float32),
2994*da0073e9SAndroid Build Coastguard Worker            make_tensor((4, 4), device=device, dtype=torch.complex64),
2995*da0073e9SAndroid Build Coastguard Worker            make_tensor((4, 4), device=device, dtype=torch.int64),
2996*da0073e9SAndroid Build Coastguard Worker        )
2997*da0073e9SAndroid Build Coastguard Worker
2998*da0073e9SAndroid Build Coastguard Worker        spacing = (
2999*da0073e9SAndroid Build Coastguard Worker            make_tensor((1,), device='cpu', dtype=torch.float32).item(),
3000*da0073e9SAndroid Build Coastguard Worker            make_tensor((1,), device='cpu', dtype=torch.int64).item(),
3001*da0073e9SAndroid Build Coastguard Worker            make_tensor((1,), device='cpu', dtype=torch.complex64).item(),
3002*da0073e9SAndroid Build Coastguard Worker            make_tensor((2,), device='cpu', dtype=torch.float32, low=0.1).tolist(),
3003*da0073e9SAndroid Build Coastguard Worker            make_tensor((2,), device='cpu', dtype=torch.int64, low=1).tolist(),
3004*da0073e9SAndroid Build Coastguard Worker            make_tensor((2,), device='cpu', dtype=torch.complex64).tolist(),
3005*da0073e9SAndroid Build Coastguard Worker            [make_tensor((4,), device=device, dtype=torch.float32),
3006*da0073e9SAndroid Build Coastguard Worker             make_tensor((4,), device=device, dtype=torch.float32)],
3007*da0073e9SAndroid Build Coastguard Worker            [make_tensor((4,), device=device, dtype=torch.int64),
3008*da0073e9SAndroid Build Coastguard Worker             make_tensor((4,), device=device, dtype=torch.int64)],
3009*da0073e9SAndroid Build Coastguard Worker            [make_tensor((4,), device=device, dtype=torch.complex64),
3010*da0073e9SAndroid Build Coastguard Worker             make_tensor((4,), device=device, dtype=torch.complex64)],
3011*da0073e9SAndroid Build Coastguard Worker        )
3012*da0073e9SAndroid Build Coastguard Worker
3013*da0073e9SAndroid Build Coastguard Worker        for input, spacing_or_coord, edge_order in product(inputs, spacing, [1, 2]):
3014*da0073e9SAndroid Build Coastguard Worker            input_np = input.cpu().numpy()
3016*da0073e9SAndroid Build Coastguard Worker            actual = torch.gradient(input, spacing=spacing_or_coord, dim=(0, 1), edge_order=edge_order)
3017*da0073e9SAndroid Build Coastguard Worker            spacing_or_coord_wrapped = self._wrap_to_list(spacing_or_coord)
3018*da0073e9SAndroid Build Coastguard Worker            spacing_or_coord_np = []
3019*da0073e9SAndroid Build Coastguard Worker            if torch.is_tensor(spacing_or_coord_wrapped[0]) and torch.device(spacing_or_coord_wrapped[0].device).type != 'cpu':
3020*da0073e9SAndroid Build Coastguard Worker                for i in range(len(spacing_or_coord_wrapped)):
3021*da0073e9SAndroid Build Coastguard Worker                    spacing_or_coord_np.append(spacing_or_coord_wrapped[i].detach().clone().cpu().numpy())
3022*da0073e9SAndroid Build Coastguard Worker            else:
3023*da0073e9SAndroid Build Coastguard Worker                spacing_or_coord_np = spacing_or_coord_wrapped
3024*da0073e9SAndroid Build Coastguard Worker            expected = np.gradient(input_np, *spacing_or_coord_np, axis=(0, 1), edge_order=edge_order)
3025*da0073e9SAndroid Build Coastguard Worker            if actual[0].dtype == torch.complex64 and input.dtype != torch.complex64:
3026*da0073e9SAndroid Build Coastguard Worker                for i in range(len(actual)):
3027*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(actual[i].real, expected[i].real, exact_dtype=False)
3028*da0073e9SAndroid Build Coastguard Worker                    # NumPy does not promote the result to complex when the spacing is complex but the input is real,
3029*da0073e9SAndroid Build Coastguard Worker                    # so its result stays real and all imaginary parts are expected to be zero.
3030*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False)
3031*da0073e9SAndroid Build Coastguard Worker            else:
3032*da0073e9SAndroid Build Coastguard Worker                actual, expected = self._inf_nan_preprocess(list(actual), expected)
3033*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False)
3034*da0073e9SAndroid Build Coastguard Worker
3035*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3036*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.long, torch.float32, torch.complex64)
3037*da0073e9SAndroid Build Coastguard Worker    def test_gradient_spacing_list_length_error(self, device, dtype):
3038*da0073e9SAndroid Build Coastguard Worker        t = make_tensor((2, 2), device=device, dtype=dtype)
3039*da0073e9SAndroid Build Coastguard Worker
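        # When spacing is passed as a tuple of tensors or scalars, torch.gradient
        # expects one entry per differentiated dimension (two for this 2-D input);
        # other lengths are expected to raise, as checked below.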
3040*da0073e9SAndroid Build Coastguard Worker        spacing = (make_tensor((2,), device=device, dtype=dtype),)
3041*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'):
3042*da0073e9SAndroid Build Coastguard Worker            torch.gradient(t, spacing=spacing)
3043*da0073e9SAndroid Build Coastguard Worker
3044*da0073e9SAndroid Build Coastguard Worker        spacing = (make_tensor((2,), device=device, dtype=dtype),) * 2
3045*da0073e9SAndroid Build Coastguard Worker        torch.gradient(t, spacing=spacing)
3046*da0073e9SAndroid Build Coastguard Worker
3047*da0073e9SAndroid Build Coastguard Worker        spacing = (make_tensor((2,), device=device, dtype=dtype),) * 3
3048*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'):
3049*da0073e9SAndroid Build Coastguard Worker            torch.gradient(t, spacing=spacing)
3050*da0073e9SAndroid Build Coastguard Worker
3051*da0073e9SAndroid Build Coastguard Worker        spacing = (2,)
3052*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'):
3053*da0073e9SAndroid Build Coastguard Worker            torch.gradient(t, spacing=spacing)
3054*da0073e9SAndroid Build Coastguard Worker
3055*da0073e9SAndroid Build Coastguard Worker        spacing = (2, 2)
3056*da0073e9SAndroid Build Coastguard Worker        torch.gradient(t, spacing=spacing)
3057*da0073e9SAndroid Build Coastguard Worker
3058*da0073e9SAndroid Build Coastguard Worker        spacing = (2, 2, 2)
3059*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'):
3060*da0073e9SAndroid Build Coastguard Worker            torch.gradient(t, spacing=spacing)
3061*da0073e9SAndroid Build Coastguard Worker
3062*da0073e9SAndroid Build Coastguard Worker    def _test_large_cum_fn_helper(self, x, fn):
3063*da0073e9SAndroid Build Coastguard Worker        expected = fn(x.cpu().float())
3064*da0073e9SAndroid Build Coastguard Worker        actual = fn(x).cpu().float()
3065*da0073e9SAndroid Build Coastguard Worker        # Avoid self.assertEqual to save memory.
3066*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(expected, actual)
3067*da0073e9SAndroid Build Coastguard Worker
3068*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration")
3069*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_JETSON, "psutil issue for largeTensorTest. Too large for Jetson.")
3070*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
3071*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half)  # use a small dtype to avoid OOM
3072*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('25GB', device='cpu')
3073*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('4GB', device='cuda')
3074*da0073e9SAndroid Build Coastguard Worker    def test_large_cumsum(self, device, dtype):
3075*da0073e9SAndroid Build Coastguard Worker        # initialization to avoid overflow and half caveats
3076*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2**30 + 200, device=device, dtype=dtype)
3077*da0073e9SAndroid Build Coastguard Worker        x[::3] = -3
3078*da0073e9SAndroid Build Coastguard Worker        x[1::3] = 2
3079*da0073e9SAndroid Build Coastguard Worker        x[2::3] = 1
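        # Each repeating triple (-3, 2, 1) sums to zero, so the running sum stays
        # small and exactly representable in float16 despite the large length.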
3080*da0073e9SAndroid Build Coastguard Worker        self._test_large_cum_fn_helper(x, lambda x: torch.cumsum(x, 0))
3081*da0073e9SAndroid Build Coastguard Worker
3082*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
3083*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half)  # use a small dtype to avoid OOM
3084*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('25GB', device='cpu')
3085*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('4GB', device='cuda')
3086*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_JETSON, "psutil issue for largeTensorTest. Too large for Jetson.")
3087*da0073e9SAndroid Build Coastguard Worker    def test_large_cumprod(self, device, dtype):
3088*da0073e9SAndroid Build Coastguard Worker        # initialization to avoid overflow and half caveats
3089*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2**30 + 200, device=device, dtype=dtype)
3090*da0073e9SAndroid Build Coastguard Worker        x[::3] = 8
3091*da0073e9SAndroid Build Coastguard Worker        x[1::3] = .25
3092*da0073e9SAndroid Build Coastguard Worker        x[2::3] = .5
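        # Each repeating triple (8, 0.25, 0.5) multiplies to one, so the running
        # product stays bounded (8, 2, 1, ...) and avoids float16 overflow/underflow.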
3093*da0073e9SAndroid Build Coastguard Worker        self._test_large_cum_fn_helper(x, lambda x: torch.cumprod(x, 0))
3094*da0073e9SAndroid Build Coastguard Worker
3095*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Torchdynamo fails with unknown reason")
3096*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
3097*da0073e9SAndroid Build Coastguard Worker    def test_discontiguous_out_cumsum(self, device):
3098*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 8, device=device)
3099*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(4, 16, device=device)[:, ::2]
3100*da0073e9SAndroid Build Coastguard Worker        out = torch.cumsum(x, 0)
3101*da0073e9SAndroid Build Coastguard Worker        torch.cumsum(x, 0, out=y)
3102*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(y.is_contiguous())
3103*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, y, atol=0., rtol=0.)
3104*da0073e9SAndroid Build Coastguard Worker
3105*da0073e9SAndroid Build Coastguard Worker    def _test_cumminmax_helper(self, x, fn, expected_val, expected_ind):
3106*da0073e9SAndroid Build Coastguard Worker        val, ind = fn(x, -1)
3107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(val, expected_val, atol=0, rtol=0)
3108*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ind, expected_ind, atol=0, rtol=0)
3109*da0073e9SAndroid Build Coastguard Worker        out_val = torch.empty_like(val).t().contiguous().t()
3110*da0073e9SAndroid Build Coastguard Worker        out_ind = torch.empty_like(ind).t().contiguous().t()
3111*da0073e9SAndroid Build Coastguard Worker        fn(x, -1, out=(out_val, out_ind))
3112*da0073e9SAndroid Build Coastguard Worker        # TODO: Fix this. It reproduces with aot_eager too, and looks like a functionalization bug.
3113*da0073e9SAndroid Build Coastguard Worker        # (the problematic case seems rare, as we're calling an out= op directly from user code,
3114*da0073e9SAndroid Build Coastguard Worker        # where the passed-in out tensors are non-contiguous).
3115*da0073e9SAndroid Build Coastguard Worker        if not TEST_WITH_TORCHINDUCTOR:
3116*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(out_val.is_contiguous())
3117*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(out_ind.is_contiguous())
3118*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_val, expected_val, atol=0, rtol=0)
3119*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ind, expected_ind, atol=0, rtol=0)
3120*da0073e9SAndroid Build Coastguard Worker
3121*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
3122*da0073e9SAndroid Build Coastguard Worker    def test_cummax_discontiguous(self, device):
3123*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[0, 1, 2, 3, 2, 1], [4, 5, 6, 5, 6, 7]], device=device, dtype=torch.float).t().contiguous().t()
3124*da0073e9SAndroid Build Coastguard Worker        expected_val = torch.tensor([[0, 1, 2, 3, 3, 3], [4, 5, 6, 6, 6, 7]], device=device, dtype=torch.float)
3125*da0073e9SAndroid Build Coastguard Worker        expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 2, 4, 5]], device=device, dtype=torch.long)
3126*da0073e9SAndroid Build Coastguard Worker        self._test_cumminmax_helper(x, torch.cummax, expected_val, expected_ind)
3127*da0073e9SAndroid Build Coastguard Worker
3128*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
3129*da0073e9SAndroid Build Coastguard Worker    def test_cummin_discontiguous(self, device):
3130*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[3, 2, 1, 0, 1, 2], [7, 6, 5, 4, 5, 2]], device=device, dtype=torch.float).t().contiguous().t()
3131*da0073e9SAndroid Build Coastguard Worker        expected_val = torch.tensor([[3, 2, 1, 0, 0, 0], [7, 6, 5, 4, 4, 2]], device=device, dtype=torch.float)
3132*da0073e9SAndroid Build Coastguard Worker        expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 3, 3, 5]], device=device, dtype=torch.long)
3133*da0073e9SAndroid Build Coastguard Worker        self._test_cumminmax_helper(x, torch.cummin, expected_val, expected_ind)
3134*da0073e9SAndroid Build Coastguard Worker
3135*da0073e9SAndroid Build Coastguard Worker    def test_bool_tensor_value_change(self, device):
3136*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([True, False], dtype=torch.bool, device=device)
3137*da0073e9SAndroid Build Coastguard Worker        x[0] = False
3138*da0073e9SAndroid Build Coastguard Worker        x[1] = True
3139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool, device=device))
3140*da0073e9SAndroid Build Coastguard Worker
3141*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to shape ops test suite
3142*da0073e9SAndroid Build Coastguard Worker    def test_unfold_all_devices_and_dtypes(self, device):
3143*da0073e9SAndroid Build Coastguard Worker        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
3144*da0073e9SAndroid Build Coastguard Worker
3145*da0073e9SAndroid Build Coastguard Worker            if dt == torch.bool:
3146*da0073e9SAndroid Build Coastguard Worker                x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
3147*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
3148*da0073e9SAndroid Build Coastguard Worker            else:
3149*da0073e9SAndroid Build Coastguard Worker                x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
3150*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
3151*da0073e9SAndroid Build Coastguard Worker
3152*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to shape ops test suite
3153*da0073e9SAndroid Build Coastguard Worker    def test_unfold_scalars(self, device):
3154*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(0.5, device=device)
3155*da0073e9SAndroid Build Coastguard Worker        # unfold on a 0-dimensional tensor should always return a 1-dimensional
3156*da0073e9SAndroid Build Coastguard Worker        # tensor of shape [size] (i.e., the second parameter to unfold)
3157*da0073e9SAndroid Build Coastguard Worker
3158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 1))
3159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 2))
3160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([0.5], device=device), x.unfold(0, 1, 1))
3161*da0073e9SAndroid Build Coastguard Worker
3162*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to data movement test suite
3163*da0073e9SAndroid Build Coastguard Worker    def test_copy_all_dtypes_and_devices(self, device):
3164*da0073e9SAndroid Build Coastguard Worker        from copy import copy
3165*da0073e9SAndroid Build Coastguard Worker        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
3166*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device)
3167*da0073e9SAndroid Build Coastguard Worker            x_clone = x.clone()
3168*da0073e9SAndroid Build Coastguard Worker            y = copy(x)
3169*da0073e9SAndroid Build Coastguard Worker            y.fill_(1)
3170*da0073e9SAndroid Build Coastguard Worker            # copy is a shallow copy: it only copies the tensor view,
3171*da0073e9SAndroid Build Coastguard Worker            # not the underlying data
3172*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, y)
3173*da0073e9SAndroid Build Coastguard Worker
3174*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3175*da0073e9SAndroid Build Coastguard Worker    def test_bfloat16_neg_abs(self, device):
3176*da0073e9SAndroid Build Coastguard Worker        src = torch.randn(256)
3177*da0073e9SAndroid Build Coastguard Worker        src[0] = torch.nan
3178*da0073e9SAndroid Build Coastguard Worker        src[1] = -torch.nan
3179*da0073e9SAndroid Build Coastguard Worker        src[2] = torch.inf
3180*da0073e9SAndroid Build Coastguard Worker        src[3] = -torch.inf
3181*da0073e9SAndroid Build Coastguard Worker        src_bf16 = src.bfloat16()
3182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(src.neg().bfloat16(), src_bf16.neg())
3183*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(src.abs().bfloat16(), src_bf16.abs())
3184*da0073e9SAndroid Build Coastguard Worker
3185*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3186*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half)
3187*da0073e9SAndroid Build Coastguard Worker    def test_reduced_type_float_copy(self, device, dtype):
3188*da0073e9SAndroid Build Coastguard Worker        for shape in [(20, 7), (249, 137), (1029, 917), (1, 7, 19, 17), (3, 77, 1091)]:
3189*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(shape, dtype=torch.float, device=device)
3190*da0073e9SAndroid Build Coastguard Worker            out1 = input.to(dtype=dtype)
3191*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input, out1, atol=None, rtol=None, exact_dtype=False)
3192*da0073e9SAndroid Build Coastguard Worker            out2 = out1.to(torch.float)
3193*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out2, out1, atol=0, rtol=0, exact_dtype=False)
3194*da0073e9SAndroid Build Coastguard Worker
3195*da0073e9SAndroid Build Coastguard Worker            input_s = input[..., ::2, :]
3196*da0073e9SAndroid Build Coastguard Worker            out1 = input_s.to(dtype=dtype)
3197*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input_s, out1, atol=None, rtol=None, exact_dtype=False)
3198*da0073e9SAndroid Build Coastguard Worker            out2 = out1.to(torch.float)
3199*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out2, out1, atol=0, rtol=0, exact_dtype=False)
3200*da0073e9SAndroid Build Coastguard Worker
3201*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to data movement test suite
3202*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3203*da0073e9SAndroid Build Coastguard Worker    def test_copy_math_view(self, device):
3204*da0073e9SAndroid Build Coastguard Worker        for dst_dtype, src_dtype in [
3205*da0073e9SAndroid Build Coastguard Worker                (torch.float32, torch.float32),
3206*da0073e9SAndroid Build Coastguard Worker                (torch.float64, torch.float32),
3207*da0073e9SAndroid Build Coastguard Worker                (torch.int64, torch.int32),
3208*da0073e9SAndroid Build Coastguard Worker                (torch.complex128, torch.complex64),
3209*da0073e9SAndroid Build Coastguard Worker        ]:
3210*da0073e9SAndroid Build Coastguard Worker            src = make_tensor((100,), dtype=src_dtype, device=device)
3211*da0073e9SAndroid Build Coastguard Worker            dst = torch.empty(100, dtype=dst_dtype, device=device)
3212*da0073e9SAndroid Build Coastguard Worker
3213*da0073e9SAndroid Build Coastguard Worker            dst.copy_(src)
3214*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src, exact_dtype=False)
3215*da0073e9SAndroid Build Coastguard Worker
3216*da0073e9SAndroid Build Coastguard Worker            dst.copy_(src._neg_view())
3217*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src.neg(), exact_dtype=False)
3218*da0073e9SAndroid Build Coastguard Worker
3219*da0073e9SAndroid Build Coastguard Worker            dst._neg_view().copy_(torch._neg_view(src))
3220*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src, exact_dtype=False)
3221*da0073e9SAndroid Build Coastguard Worker
3222*da0073e9SAndroid Build Coastguard Worker            dst._neg_view().copy_(src)
3223*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src.neg(), exact_dtype=False)
3224*da0073e9SAndroid Build Coastguard Worker
3225*da0073e9SAndroid Build Coastguard Worker            # issue: https://github.com/pytorch/pytorch/issues/106051
3226*da0073e9SAndroid Build Coastguard Worker            dst._neg_view().copy_(dst)
3227*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src, exact_dtype=False)
3228*da0073e9SAndroid Build Coastguard Worker
3229*da0073e9SAndroid Build Coastguard Worker        for dst_dtype, src_dtype in [
3230*da0073e9SAndroid Build Coastguard Worker                (torch.complex64, torch.complex64),
3231*da0073e9SAndroid Build Coastguard Worker                (torch.complex128, torch.complex64),
3232*da0073e9SAndroid Build Coastguard Worker        ]:
3233*da0073e9SAndroid Build Coastguard Worker            src = make_tensor((100,), dtype=src_dtype, device=device)
3234*da0073e9SAndroid Build Coastguard Worker            dst = torch.empty(100, dtype=dst_dtype, device=device)
3235*da0073e9SAndroid Build Coastguard Worker
3236*da0073e9SAndroid Build Coastguard Worker            dst.conj().copy_(src)
3237*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src.conj_physical(), exact_dtype=False)
3238*da0073e9SAndroid Build Coastguard Worker
3239*da0073e9SAndroid Build Coastguard Worker            dst.conj().copy_(src._neg_view())
3240*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src.neg().conj_physical(), exact_dtype=False)
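        # A minimal sketch of the math-view semantics exercised above, with
        # hypothetical values (not part of the test): writing through a neg/conj
        # view stores the negated/conjugated values in the base tensor.
        #   t = torch.zeros(2)
        #   t._neg_view().copy_(torch.tensor([1., 2.]))        # t is now [-1., -2.]
        #   c = torch.zeros(2, dtype=torch.complex64)
        #   c.conj().copy_(torch.tensor([1 + 1j, 2 - 1j]))     # c is now [1 - 1j, 2 + 1j]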
3241*da0073e9SAndroid Build Coastguard Worker
3242*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to data movement test suite
3243*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3244*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int64, torch.float32, torch.complex64)
3245*da0073e9SAndroid Build Coastguard Worker    def test_copy_transpose_math_view(self, device, dtype):
3246*da0073e9SAndroid Build Coastguard Worker        src = make_tensor((100, 100), dtype=dtype, device=device).transpose(0, 1)
3247*da0073e9SAndroid Build Coastguard Worker        dst = torch.empty((100, 100), dtype=dtype, device=device)
3248*da0073e9SAndroid Build Coastguard Worker
3249*da0073e9SAndroid Build Coastguard Worker        dst._neg_view().copy_(src)
3250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, -src)
3251*da0073e9SAndroid Build Coastguard Worker        dst._neg_view().copy_(src._neg_view())
3252*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, src)
3253*da0073e9SAndroid Build Coastguard Worker        dst.copy_(src._neg_view())
3254*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, -src)
3255*da0073e9SAndroid Build Coastguard Worker
3256*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
3257*da0073e9SAndroid Build Coastguard Worker            dst.conj().copy_(src)
3258*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src.conj_physical())
3259*da0073e9SAndroid Build Coastguard Worker            dst.conj().copy_(src.conj())
3260*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src)
3261*da0073e9SAndroid Build Coastguard Worker            dst.copy_(src.conj())
3262*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, src.conj_physical())
3263*da0073e9SAndroid Build Coastguard Worker
3264*da0073e9SAndroid Build Coastguard Worker    def test_clone_all_dtypes_and_devices(self, device):
3265*da0073e9SAndroid Build Coastguard Worker        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
3266*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor((1, 1), dtype=dt, device=device)
3267*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
3268*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, y)
3269*da0073e9SAndroid Build Coastguard Worker
3270*da0073e9SAndroid Build Coastguard Worker    def test_clone_zero_stride_dim(self, device):
3271*da0073e9SAndroid Build Coastguard Worker        # stride zero, size 1 axis, not contiguous
3272*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
3273*da0073e9SAndroid Build Coastguard Worker        y = x.as_strided([2, 1, 5], [1, 0, 2])
3274*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y.clone())
3275*da0073e9SAndroid Build Coastguard Worker
3276*da0073e9SAndroid Build Coastguard Worker    def test_clone_not_memory_dense(self):
3277*da0073e9SAndroid Build Coastguard Worker        # github issue: https://github.com/pytorch/pytorch/issues/64176
3278*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 8).t()[::2, ::2]
3279*da0073e9SAndroid Build Coastguard Worker        y = x.clone()
3280*da0073e9SAndroid Build Coastguard Worker        # should retain permutation after densification
3281*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.stride() == (1, 4))
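        # Why (1, 4): x has shape (4, 5) with strides (2, 16), i.e. dim 0 varies
        # fastest in memory.  clone() densifies the storage but keeps that
        # dimension ordering, so the dense strides are (1, 4) rather than the
        # row-major (5, 1).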
3282*da0073e9SAndroid Build Coastguard Worker
3283*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to elementwise ternary test suite
3284*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
3285*da0073e9SAndroid Build Coastguard Worker    @dtypes(*set(get_all_math_dtypes('cpu')))
3286*da0073e9SAndroid Build Coastguard Worker    def test_addcmul(self, device, dtype):
3287*da0073e9SAndroid Build Coastguard Worker        # Returns floating or integral scalar corresponding to dtype
3288*da0073e9SAndroid Build Coastguard Worker        def _number(floating, integer, dtype):
3289*da0073e9SAndroid Build Coastguard Worker            if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
3290*da0073e9SAndroid Build Coastguard Worker                return floating
3291*da0073e9SAndroid Build Coastguard Worker            elif dtype in [torch.cfloat, torch.cdouble]:
3292*da0073e9SAndroid Build Coastguard Worker                return floating * (1 + 1j)
3293*da0073e9SAndroid Build Coastguard Worker            else:
3294*da0073e9SAndroid Build Coastguard Worker                return integer
3295*da0073e9SAndroid Build Coastguard Worker
3296*da0073e9SAndroid Build Coastguard Worker        def rand_tensor(size, dtype, device):
3297*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point or dtype.is_complex:
3298*da0073e9SAndroid Build Coastguard Worker                return torch.rand(size=size, dtype=dtype, device=device)
3299*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.uint8:
3300*da0073e9SAndroid Build Coastguard Worker                return torch.randint(1, 5, size=size, dtype=dtype, device=device)
3301*da0073e9SAndroid Build Coastguard Worker            else:
3302*da0073e9SAndroid Build Coastguard Worker                return torch.randint(-5, 5, size=size, dtype=dtype, device=device)
3303*da0073e9SAndroid Build Coastguard Worker
3304*da0073e9SAndroid Build Coastguard Worker        a = rand_tensor((2, 2), dtype=dtype, device=device)
3305*da0073e9SAndroid Build Coastguard Worker        b = rand_tensor((2, 2), dtype=dtype, device=device)
3306*da0073e9SAndroid Build Coastguard Worker        c = rand_tensor((2, 2), dtype=dtype, device=device)
3307*da0073e9SAndroid Build Coastguard Worker
3308*da0073e9SAndroid Build Coastguard Worker        alpha = _number(0.5, 3, dtype)
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker        actual = torch.addcmul(a, b, c, value=alpha)
3311*da0073e9SAndroid Build Coastguard Worker        expected = a + alpha * b * c
3312*da0073e9SAndroid Build Coastguard Worker
3313*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
3314*da0073e9SAndroid Build Coastguard Worker
3315*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
3316*da0073e9SAndroid Build Coastguard Worker                UserWarning, "This overload of addcmul is deprecated"):
3317*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, torch.addcmul(a, alpha, b, c))
3318*da0073e9SAndroid Build Coastguard Worker
3319*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and dtype == torch.half:
3320*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([60000.0], device=device, dtype=dtype)
3321*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([60000.0], device=device, dtype=dtype)
3322*da0073e9SAndroid Build Coastguard Worker            c = torch.tensor([2.0], device=device, dtype=dtype)
3323*da0073e9SAndroid Build Coastguard Worker            out = torch.addcmul(a, b, c, value=-1)
3324*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(not (out.isnan() or out.isinf()))
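        # Sketch of the fp16 case above: the result 60000 + (-1) * 60000 * 2 = -60000
        # is representable in half (max finite ~65504), but the naive intermediate
        # 60000 * 2 = 120000 would overflow to inf if computed in fp16, so the
        # check passes only if the kernel accumulates in a wider type.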
3325*da0073e9SAndroid Build Coastguard Worker
3326*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to shape ops test suite
3327*da0073e9SAndroid Build Coastguard Worker    def test_narrow_empty(self, device):
3328*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4, device=device)
3329*da0073e9SAndroid Build Coastguard Worker        for d in range(x.dim()):
3330*da0073e9SAndroid Build Coastguard Worker            y = x.narrow(d, x.size(d), 0)
3331*da0073e9SAndroid Build Coastguard Worker            sz = list(x.size())
3332*da0073e9SAndroid Build Coastguard Worker            sz[d] = 0
3333*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sz, y.size())
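        # For example, for d == 0: x.narrow(0, 2, 0) starts at index 2 (== x.size(0))
        # and takes 0 elements, yielding an empty tensor of shape (0, 3, 4).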
3334*da0073e9SAndroid Build Coastguard Worker
3335*da0073e9SAndroid Build Coastguard Worker    def test_narrow_copy_non_contiguous(self, device):
3336*da0073e9SAndroid Build Coastguard Worker        # see https://github.com/pytorch/pytorch/issues/91690.
3337*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(10, 2, device=device).movedim(-1, 0)
3338*da0073e9SAndroid Build Coastguard Worker        expected = torch.narrow_copy(inp.contiguous(), 1, 0, 10)
3339*da0073e9SAndroid Build Coastguard Worker        actual = torch.narrow_copy(inp, 1, 0, 10)
3340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
3341*da0073e9SAndroid Build Coastguard Worker
3342*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to indexing test suite
3343*da0073e9SAndroid Build Coastguard Worker    @parametrize("reduce", ['prod', 'amin', 'amax', 'mean'])
3344*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half, torch.bfloat16))
3345*da0073e9SAndroid Build Coastguard Worker    def test_index_reduce(self, device, dtype, reduce):
3346*da0073e9SAndroid Build Coastguard Worker        size = (3, 4, 5)
3347*da0073e9SAndroid Build Coastguard Worker        index_dtypes = [torch.int, torch.long]
3348*da0073e9SAndroid Build Coastguard Worker        include_selfs = [True, False]
3349*da0073e9SAndroid Build Coastguard Worker        amin_init = float('inf') if dtype.is_floating_point else torch.iinfo(dtype).max
3350*da0073e9SAndroid Build Coastguard Worker        amax_init = -float('inf') if dtype.is_floating_point else torch.iinfo(dtype).min
3351*da0073e9SAndroid Build Coastguard Worker        reduction_init = {'prod': 1, 'mean': 0, 'amin': amin_init, 'amax': amax_init}
3352*da0073e9SAndroid Build Coastguard Worker
3353*da0073e9SAndroid Build Coastguard Worker        for dest_noncontig, src_noncontig, index_noncontig in product([True, False], repeat=3):
3354*da0073e9SAndroid Build Coastguard Worker            for idx_dtype, include_self in product(index_dtypes, include_selfs):
3355*da0073e9SAndroid Build Coastguard Worker                for dim in range(len(size)):
3356*da0073e9SAndroid Build Coastguard Worker                    num_src = np.random.randint(10)
3357*da0073e9SAndroid Build Coastguard Worker                    num_dest = size[dim]
3358*da0073e9SAndroid Build Coastguard Worker                    dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=dest_noncontig)
3359*da0073e9SAndroid Build Coastguard Worker                    src_size = size[:dim] + (num_src,) + size[dim + 1:]
3360*da0073e9SAndroid Build Coastguard Worker                    src = make_tensor(src_size, device=device, dtype=dtype, noncontiguous=src_noncontig)
3361*da0073e9SAndroid Build Coastguard Worker                    idx = torch.testing.make_tensor(
3362*da0073e9SAndroid Build Coastguard Worker                        num_src, low=0, high=num_dest, dtype=idx_dtype, device=device, noncontiguous=index_noncontig
3363*da0073e9SAndroid Build Coastguard Worker                    )
3364*da0073e9SAndroid Build Coastguard Worker                    expected = dest.clone()
3365*da0073e9SAndroid Build Coastguard Worker                    dest.index_reduce_(dim, idx, src, reduce, include_self=include_self)
3366*da0073e9SAndroid Build Coastguard Worker                    # fill rows in idx with reduction inits if include_self=False
3367*da0073e9SAndroid Build Coastguard Worker                    if (not include_self):
3368*da0073e9SAndroid Build Coastguard Worker                        expected.index_fill_(dim, idx.long(), reduction_init[reduce])
3369*da0073e9SAndroid Build Coastguard Worker                    expected = expected.transpose(0, dim)
3370*da0073e9SAndroid Build Coastguard Worker                    src = src.transpose(0, dim)
3371*da0073e9SAndroid Build Coastguard Worker                    for i in range(num_src):
3372*da0073e9SAndroid Build Coastguard Worker                        if reduce == 'prod':
3373*da0073e9SAndroid Build Coastguard Worker                            expected[idx[i]] *= src[i]
3374*da0073e9SAndroid Build Coastguard Worker                        elif reduce == 'amin':
3375*da0073e9SAndroid Build Coastguard Worker                            torch.minimum(expected[idx[i]], src[i], out=expected[idx[i]])
3376*da0073e9SAndroid Build Coastguard Worker                        elif reduce == 'amax':
3377*da0073e9SAndroid Build Coastguard Worker                            torch.maximum(expected[idx[i]], src[i], out=expected[idx[i]])
3378*da0073e9SAndroid Build Coastguard Worker                        else:
3379*da0073e9SAndroid Build Coastguard Worker                            expected[idx[i]] += src[i]
3380*da0073e9SAndroid Build Coastguard Worker                    if reduce == 'mean':
3381*da0073e9SAndroid Build Coastguard Worker                        counts = torch.ones_like(expected) if include_self else torch.zeros_like(expected)
3382*da0073e9SAndroid Build Coastguard Worker                        counts.index_add_(0, idx, torch.ones_like(src))
3383*da0073e9SAndroid Build Coastguard Worker                        counts.masked_fill_(counts == 0, 1)
3384*da0073e9SAndroid Build Coastguard Worker                        if (dtype.is_floating_point):
3385*da0073e9SAndroid Build Coastguard Worker                            expected.div_(counts)
3386*da0073e9SAndroid Build Coastguard Worker                        else:
3387*da0073e9SAndroid Build Coastguard Worker                            expected.div_(counts, rounding_mode="floor")
3388*da0073e9SAndroid Build Coastguard Worker                    expected = expected.transpose(0, dim)
3389*da0073e9SAndroid Build Coastguard Worker
3390*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(dest, expected)
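        # A tiny worked example of the semantics checked above, with hypothetical
        # values, for reduce='prod':
        #   dest = torch.tensor([10., 20., 30.])
        #   src  = torch.tensor([1., 2., 3.])
        #   idx  = torch.tensor([0, 0, 2])
        #   dest.index_reduce_(0, idx, src, 'prod', include_self=True)
        #   # dest -> [10*1*2, 20, 30*3] == [20., 20., 90.]
        # With include_self=False, rows 0 and 2 are first reset to the 'prod'
        # init value 1, giving [1*1*2, 20, 1*3] == [2., 20., 3.].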
3391*da0073e9SAndroid Build Coastguard Worker
3392*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test indexing
3393*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3394*da0073e9SAndroid Build Coastguard Worker    def test_index_copy(self, device, dtype):
3395*da0073e9SAndroid Build Coastguard Worker        # We just test for num_copy <= num_dest, as otherwise there are repeated indices
3396*da0073e9SAndroid Build Coastguard Worker        # and the behavior is undefined
3397*da0073e9SAndroid Build Coastguard Worker        num_copy, num_dest = 3, 5
3398*da0073e9SAndroid Build Coastguard Worker
3399*da0073e9SAndroid Build Coastguard Worker        def make_arg(batch_sizes, n, dim, contig):
3400*da0073e9SAndroid Build Coastguard Worker            size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
3401*da0073e9SAndroid Build Coastguard Worker            return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig)
3402*da0073e9SAndroid Build Coastguard Worker
3403*da0073e9SAndroid Build Coastguard Worker        def ref_index_copy(tgt, dim, idx, src):
3404*da0073e9SAndroid Build Coastguard Worker            for i in range(idx.size(0)):
3405*da0073e9SAndroid Build Coastguard Worker                idx_dest = dim * (slice(None),) + (idx[i],)
3406*da0073e9SAndroid Build Coastguard Worker                idx_src = dim * (slice(None),) + (i,)
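                # e.g. for dim == 1 this builds idx_dest == (slice(None), idx[i]),
                # so the assignment below is tgt[:, idx[i]] = src[:, i]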
3407*da0073e9SAndroid Build Coastguard Worker                tgt[idx_dest] = src[idx_src]
3408*da0073e9SAndroid Build Coastguard Worker
3409*da0073e9SAndroid Build Coastguard Worker        # More thorough testing, following the same pattern as index_add
3410*da0073e9SAndroid Build Coastguard Worker        for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
3411*da0073e9SAndroid Build Coastguard Worker            for other_sizes in ((), (4, 5)):
3412*da0073e9SAndroid Build Coastguard Worker                for dim in range(len(other_sizes)):
3413*da0073e9SAndroid Build Coastguard Worker                    dest = make_arg(other_sizes, num_dest, dim, dest_contig)
3414*da0073e9SAndroid Build Coastguard Worker                    src = make_arg(other_sizes, num_copy, dim, src_contig)
3415*da0073e9SAndroid Build Coastguard Worker                    idx = torch.randperm(num_dest, dtype=torch.int64, device=device)[:num_copy]
3416*da0073e9SAndroid Build Coastguard Worker                    if not index_contig:
3417*da0073e9SAndroid Build Coastguard Worker                        idx = torch.repeat_interleave(idx, 2, dim=-1)
3418*da0073e9SAndroid Build Coastguard Worker                        idx = idx[..., ::2]
3419*da0073e9SAndroid Build Coastguard Worker                    dest2 = dest.clone()
3420*da0073e9SAndroid Build Coastguard Worker                    dest.index_copy_(dim, idx, src)
3421*da0073e9SAndroid Build Coastguard Worker                    ref_index_copy(dest2, dim, idx, src)
3422*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(dest, dest2)
3423*da0073e9SAndroid Build Coastguard Worker
3424*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test indexing
3425*da0073e9SAndroid Build Coastguard Worker    # onlyNativeDeviceTypes due to an XLA error:
3426*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/53256
3427*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3428*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3429*da0073e9SAndroid Build Coastguard Worker    def test_index_copy_scalars(self, device, dtype):
3430*da0073e9SAndroid Build Coastguard Worker        # Create the 8 possible combinations of scalar sizes for target / index / source
3431*da0073e9SAndroid Build Coastguard Worker        scalars = ((make_tensor(size_t, dtype=dtype, device=device, low=None, high=None),
3432*da0073e9SAndroid Build Coastguard Worker                    make_tensor(size_i, dtype=torch.int64, device=device, low=0, high=1),
3433*da0073e9SAndroid Build Coastguard Worker                    make_tensor(size_s, dtype=dtype, device=device, low=None, high=None))
3434*da0073e9SAndroid Build Coastguard Worker                   for size_t, size_i, size_s in product([(), (1,)], repeat=3))
3435*da0073e9SAndroid Build Coastguard Worker        for target, idx, source in scalars:
3436*da0073e9SAndroid Build Coastguard Worker            target.index_copy_(0, idx, source)
3437*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(target.item(), source.item())
3438*da0073e9SAndroid Build Coastguard Worker
3439*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test indexing
3440*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3441*da0073e9SAndroid Build Coastguard Worker    def test_errors_index_copy(self, device):
3442*da0073e9SAndroid Build Coastguard Worker        # We do not test on the GPU, as the device-side CUDA assert would break the CUDA context
3443*da0073e9SAndroid Build Coastguard Worker        idx_dim = 8
3444*da0073e9SAndroid Build Coastguard Worker        tgt_dim = 5
3445*da0073e9SAndroid Build Coastguard Worker        batch_dim = 3
3446*da0073e9SAndroid Build Coastguard Worker
3447*da0073e9SAndroid Build Coastguard Worker        # Too large of an index
3448*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(batch_dim, tgt_dim, device=device)
3449*da0073e9SAndroid Build Coastguard Worker        idx = torch.full((idx_dim,), tgt_dim, device=device)
3450*da0073e9SAndroid Build Coastguard Worker        c = torch.zeros(batch_dim, idx_dim, device=device)
3451*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
3452*da0073e9SAndroid Build Coastguard Worker            a.index_copy_(1, idx, c)
3453*da0073e9SAndroid Build Coastguard Worker
3454*da0073e9SAndroid Build Coastguard Worker        # Too small (negative indices)
3455*da0073e9SAndroid Build Coastguard Worker        idx = torch.full((idx_dim,), -1, device=device)
3456*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
3457*da0073e9SAndroid Build Coastguard Worker            a.index_copy_(1, idx, c)
3458*da0073e9SAndroid Build Coastguard Worker
3459*da0073e9SAndroid Build Coastguard Worker        # Too small (very negative indices) - they should be unsupported even
3460*da0073e9SAndroid Build Coastguard Worker        # when support for negative indices is implemented for index_copy_
3461*da0073e9SAndroid Build Coastguard Worker        idx = torch.full((idx_dim,), -tgt_dim - 1, device=device)
3462*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
3463*da0073e9SAndroid Build Coastguard Worker            a.index_copy_(1, idx, c)
3464*da0073e9SAndroid Build Coastguard Worker
3465*da0073e9SAndroid Build Coastguard Worker    def _prepare_data_for_index_copy_and_add_deterministic(
3466*da0073e9SAndroid Build Coastguard Worker        self, dim: int, device: torch.device
3467*da0073e9SAndroid Build Coastguard Worker    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
3468*da0073e9SAndroid Build Coastguard Worker        assert (dim >= 0 and dim < 3)
3469*da0073e9SAndroid Build Coastguard Worker        a = [5, 4, 3]
3470*da0073e9SAndroid Build Coastguard Worker        a[dim] = 2000
3471*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(a, device=device)
3472*da0073e9SAndroid Build Coastguard Worker        b = a.copy()
3473*da0073e9SAndroid Build Coastguard Worker        elems = a[dim] * 20
3474*da0073e9SAndroid Build Coastguard Worker        b[dim] = elems
3475*da0073e9SAndroid Build Coastguard Worker        src = torch.rand(b, device=device)
3476*da0073e9SAndroid Build Coastguard Worker        index = torch.randint(a[dim], (elems,), device=device)
3477*da0073e9SAndroid Build Coastguard Worker        return (x, index, src)
3478*da0073e9SAndroid Build Coastguard Worker
3479*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test indexing
3480*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3481*da0073e9SAndroid Build Coastguard Worker    def test_index_copy_deterministic(self, device: torch.device) -> None:
3482*da0073e9SAndroid Build Coastguard Worker        for dim in range(3):
3483*da0073e9SAndroid Build Coastguard Worker            x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device)
3484*da0073e9SAndroid Build Coastguard Worker            with DeterministicGuard(True):
3485*da0073e9SAndroid Build Coastguard Worker                y0 = torch.index_copy(x, dim, index, src)
3486*da0073e9SAndroid Build Coastguard Worker
3487*da0073e9SAndroid Build Coastguard Worker            x0 = x.clone().detach()
3488*da0073e9SAndroid Build Coastguard Worker            index_list = index.tolist()
3489*da0073e9SAndroid Build Coastguard Worker            for i in range(len(index_list)):
3490*da0073e9SAndroid Build Coastguard Worker                if dim == 0:
3491*da0073e9SAndroid Build Coastguard Worker                    x0[index_list[i], :, :] = src[i, :, :]
3492*da0073e9SAndroid Build Coastguard Worker                elif dim == 1:
3493*da0073e9SAndroid Build Coastguard Worker                    x0[:, index_list[i], :] = src[:, i, :]
3494*da0073e9SAndroid Build Coastguard Worker                elif dim == 2:
3495*da0073e9SAndroid Build Coastguard Worker                    x0[:, :, index_list[i]] = src[:, :, i]
3496*da0073e9SAndroid Build Coastguard Worker
3497*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x0, y0, atol=0, rtol=0)
3498*da0073e9SAndroid Build Coastguard Worker
3499*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test indexing
3500*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3501*da0073e9SAndroid Build Coastguard Worker    def test_index_add_deterministic(self, device: torch.device) -> None:
3502*da0073e9SAndroid Build Coastguard Worker        for dim in range(3):
3503*da0073e9SAndroid Build Coastguard Worker            x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device)
3504*da0073e9SAndroid Build Coastguard Worker            alpha = random.random() + 1
3505*da0073e9SAndroid Build Coastguard Worker            # on CPU it should be deterministic regardless of the deterministic mode
3506*da0073e9SAndroid Build Coastguard Worker            with DeterministicGuard(True):
3507*da0073e9SAndroid Build Coastguard Worker                y0 = torch.index_add(x, dim, index, src, alpha=alpha)
3508*da0073e9SAndroid Build Coastguard Worker                for _ in range(3):
3509*da0073e9SAndroid Build Coastguard Worker                    y = torch.index_add(x, dim, index, src, alpha=alpha)
3510*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y, y0, atol=0, rtol=0)
3511*da0073e9SAndroid Build Coastguard Worker
3512*da0073e9SAndroid Build Coastguard Worker            with DeterministicGuard(False):
3513*da0073e9SAndroid Build Coastguard Worker                for _ in range(3):
3514*da0073e9SAndroid Build Coastguard Worker                    y_nd = torch.index_add(x, dim, index, src, alpha=alpha)
3515*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)
3516*da0073e9SAndroid Build Coastguard Worker
3517*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the put operator
3518*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3519*da0073e9SAndroid Build Coastguard Worker    def test_index_put_non_accumulate_deterministic(self, device) -> None:
3520*da0073e9SAndroid Build Coastguard Worker        with DeterministicGuard(True):
3521*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
3522*da0073e9SAndroid Build Coastguard Worker                m = random.randint(10, 20)
3523*da0073e9SAndroid Build Coastguard Worker                elems = random.randint(20000, 30000)
3524*da0073e9SAndroid Build Coastguard Worker                values = torch.rand(elems, device=device)
3525*da0073e9SAndroid Build Coastguard Worker                indices = torch.randint(m, (elems,), device=device)
3526*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(m, device=device)
3527*da0073e9SAndroid Build Coastguard Worker                output = input.index_put((indices,), values, accumulate=False)
3528*da0073e9SAndroid Build Coastguard Worker
3529*da0073e9SAndroid Build Coastguard Worker                input_list = input.tolist()
3530*da0073e9SAndroid Build Coastguard Worker                indices_list = indices.tolist()
3531*da0073e9SAndroid Build Coastguard Worker                values_list = values.tolist()
3532*da0073e9SAndroid Build Coastguard Worker                for i, v in zip(indices_list, values_list):
3533*da0073e9SAndroid Build Coastguard Worker                    input_list[i] = v
3534*da0073e9SAndroid Build Coastguard Worker
3535*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output, input_list)
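        # With accumulate=False and repeated indices the op is normally
        # nondeterministic; under DeterministicGuard(True) it is expected to match
        # the serial loop above, i.e. for a repeated index the value from its last
        # occurrence wins.  Tiny sketch with hypothetical values:
        #   torch.zeros(2).index_put_((torch.tensor([0, 0]),),
        #                             torch.tensor([1., 2.]))   # -> [2., 0.]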
3536*da0073e9SAndroid Build Coastguard Worker
3537*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test indexing
3538*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3539*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
3540*da0073e9SAndroid Build Coastguard Worker    def test_index_fill(self, device, dtype):
3541*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device)
3542*da0073e9SAndroid Build Coastguard Worker        index = torch.tensor([0], device=device)
3543*da0073e9SAndroid Build Coastguard Worker        x.index_fill_(1, index, 0)
3544*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
3545*da0073e9SAndroid Build Coastguard Worker        if not x.is_complex() and not device == "meta":
3546*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Scalar"):
3547*da0073e9SAndroid Build Coastguard Worker                x.index_fill_(1, index, 1 + 1j)
3548*da0073e9SAndroid Build Coastguard Worker        # Make sure that the result stays 0-dim when applied to
3549*da0073e9SAndroid Build Coastguard Worker        # a 0-dim input
3550*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(1, dtype=dtype, device=device)
3551*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, x.index_fill(0, index, -1).dim())
3552*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, x.index_fill_(0, index, -1).dim())
3553*da0073e9SAndroid Build Coastguard Worker
3554*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test indexing
3555*da0073e9SAndroid Build Coastguard Worker    # The test fails for zero-dimensional tensors on XLA
3556*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3557*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3558*da0073e9SAndroid Build Coastguard Worker    def test_index_select(self, device, dtype):
3559*da0073e9SAndroid Build Coastguard Worker        num_src, num_out = 3, 5
3560*da0073e9SAndroid Build Coastguard Worker
3561*da0073e9SAndroid Build Coastguard Worker        def make_arg(batch_sizes, n, dim, contig):
3562*da0073e9SAndroid Build Coastguard Worker            size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
3563*da0073e9SAndroid Build Coastguard Worker            return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig)
3564*da0073e9SAndroid Build Coastguard Worker
3565*da0073e9SAndroid Build Coastguard Worker        def ref_index_select(src, dim, idx):
3566*da0073e9SAndroid Build Coastguard Worker            # numpy does not support bfloat16, so compute the reference in float and cast back
3567*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
3568*da0073e9SAndroid Build Coastguard Worker                src = src.float()
3569*da0073e9SAndroid Build Coastguard Worker            out = torch.from_numpy(np.take(src.cpu().numpy(), idx.cpu().numpy(), axis=dim))
3570*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
3571*da0073e9SAndroid Build Coastguard Worker                out = out.to(device=device, dtype=dtype)
3572*da0073e9SAndroid Build Coastguard Worker            return out
3573*da0073e9SAndroid Build Coastguard Worker
3574*da0073e9SAndroid Build Coastguard Worker        for src_contig, idx_contig in product([True, False], repeat=2):
3575*da0073e9SAndroid Build Coastguard Worker            for other_sizes in ((), (4, 5)):
3576*da0073e9SAndroid Build Coastguard Worker                for dim in range(len(other_sizes)):
3577*da0073e9SAndroid Build Coastguard Worker                    src = make_arg(other_sizes, num_src, dim, src_contig)
3578*da0073e9SAndroid Build Coastguard Worker                    idx = make_tensor(
3579*da0073e9SAndroid Build Coastguard Worker                        (num_out,), dtype=torch.int64, device=device, low=0, high=num_src, noncontiguous=not idx_contig
3580*da0073e9SAndroid Build Coastguard Worker                    )
3581*da0073e9SAndroid Build Coastguard Worker                    out = torch.index_select(src, dim, idx)
3582*da0073e9SAndroid Build Coastguard Worker                    out2 = ref_index_select(src, dim, idx)
3583*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out, out2)
3584*da0073e9SAndroid Build Coastguard Worker
3585*da0073e9SAndroid Build Coastguard Worker        for idx_type in (torch.int32, torch.int64):
3586*da0073e9SAndroid Build Coastguard Worker            other_sizes = (3, 2)
3587*da0073e9SAndroid Build Coastguard Worker            dim = 1
3588*da0073e9SAndroid Build Coastguard Worker            src = make_arg(other_sizes, num_src, dim, True)
3589*da0073e9SAndroid Build Coastguard Worker            idx = make_tensor((num_out,), dtype=idx_type, device=device, low=0, high=num_src, noncontiguous=False)
3590*da0073e9SAndroid Build Coastguard Worker            out = torch.index_select(src, dim, idx)
3591*da0073e9SAndroid Build Coastguard Worker            out2 = ref_index_select(src, dim, idx)
3592*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, out2)
3593*da0073e9SAndroid Build Coastguard Worker
3594*da0073e9SAndroid Build Coastguard Worker        # Create the 4 possible combinations of scalar sizes for source / index
3595*da0073e9SAndroid Build Coastguard Worker        scalars = ((make_tensor(size_s, dtype=dtype, device=device),
3596*da0073e9SAndroid Build Coastguard Worker                    torch.zeros(size_i, dtype=torch.int64, device=device))
3597*da0073e9SAndroid Build Coastguard Worker                   for size_s, size_i in product([(), (1,)], repeat=2))
3598*da0073e9SAndroid Build Coastguard Worker        for source, idx in scalars:
3599*da0073e9SAndroid Build Coastguard Worker            out = source.index_select(0, idx)
3600*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.item(), source.item())
3601*da0073e9SAndroid Build Coastguard Worker
3602*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the take operator
3603*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3604*da0073e9SAndroid Build Coastguard Worker    def test_take(self, device, dtype):
3605*da0073e9SAndroid Build Coastguard Worker        idx_size = (4,)
3606*da0073e9SAndroid Build Coastguard Worker
3607*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
3608*da0073e9SAndroid Build Coastguard Worker        make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)
3609*da0073e9SAndroid Build Coastguard Worker
3610*da0073e9SAndroid Build Coastguard Worker        def ref_take(src, idx):
3611*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
3612*da0073e9SAndroid Build Coastguard Worker                src = src.half()
3613*da0073e9SAndroid Build Coastguard Worker            src = src.cpu().numpy()
3614*da0073e9SAndroid Build Coastguard Worker            idx = idx.cpu().numpy()
3615*da0073e9SAndroid Build Coastguard Worker            out = torch.from_numpy(np.take(src, idx)).to(device=device, dtype=dtype)
3616*da0073e9SAndroid Build Coastguard Worker            return out
3617*da0073e9SAndroid Build Coastguard Worker
3618*da0073e9SAndroid Build Coastguard Worker        for src_contig, idx_contig, idx_reshape in product([True, False], repeat=3):
3619*da0073e9SAndroid Build Coastguard Worker            for src_size in ((5,), (4, 5)):
3620*da0073e9SAndroid Build Coastguard Worker                src = make_arg(src_size, noncontiguous=not src_contig)
3621*da0073e9SAndroid Build Coastguard Worker                idx = make_idx(idx_size, high=src.numel(), noncontiguous=not idx_contig)
3622*da0073e9SAndroid Build Coastguard Worker                if idx_reshape:
3623*da0073e9SAndroid Build Coastguard Worker                    idx = idx.reshape(2, 2)
3624*da0073e9SAndroid Build Coastguard Worker                out = torch.take(src, idx)
3625*da0073e9SAndroid Build Coastguard Worker                out2 = ref_take(src, idx)
3626*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, out2)
3627*da0073e9SAndroid Build Coastguard Worker
3628*da0073e9SAndroid Build Coastguard Worker        # Create the 4 possible combinations of scalar sizes for source / index
3629*da0073e9SAndroid Build Coastguard Worker        for size_s, size_i in product([(), (1,)], repeat=2):
3630*da0073e9SAndroid Build Coastguard Worker            source = make_arg(size_s)
3631*da0073e9SAndroid Build Coastguard Worker            idx = make_idx(size_i, high=1)
3632*da0073e9SAndroid Build Coastguard Worker            out = source.take(idx)
3633*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.item(), source.item())
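        # torch.take indexes the input as if it were flattened to 1-D and the
        # result takes the shape of the index tensor, e.g. (hypothetical values):
        #   torch.take(torch.tensor([[1, 2], [3, 4]]), torch.tensor([0, 3]))
        #   # -> tensor([1, 4])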
3634*da0073e9SAndroid Build Coastguard Worker
3635*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the put operator
3636*da0073e9SAndroid Build Coastguard Worker    # The bool dtype case does not work on GPU. See
3637*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/54317
3638*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3639*da0073e9SAndroid Build Coastguard Worker    def test_put(self, device, dtype):
3640*da0073e9SAndroid Build Coastguard Worker        src_size = (4,)
3641*da0073e9SAndroid Build Coastguard Worker
3642*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
3643*da0073e9SAndroid Build Coastguard Worker        make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)
3644*da0073e9SAndroid Build Coastguard Worker
3645*da0073e9SAndroid Build Coastguard Worker        def ref_put(dst, idx, src, accumulate):
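            # Reference semantics: put_ treats `dst` as flattened 1-D storage and
            # either overwrites (index_copy_) or accumulates (index_add_) at the
            # given linear indices.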
3646*da0073e9SAndroid Build Coastguard Worker            new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1)
3647*da0073e9SAndroid Build Coastguard Worker            new_idx = idx.contiguous().view(-1)
3648*da0073e9SAndroid Build Coastguard Worker            new_src = src.contiguous().view(-1)
3649*da0073e9SAndroid Build Coastguard Worker            method = new_dst.index_add_ if accumulate else new_dst.index_copy_
3650*da0073e9SAndroid Build Coastguard Worker            return method(0, new_idx, new_src).view_as(dst)
3651*da0073e9SAndroid Build Coastguard Worker
3652*da0073e9SAndroid Build Coastguard Worker        for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product([True, False], repeat=5):
3653*da0073e9SAndroid Build Coastguard Worker            for dst_size in ((5,), (4, 5)):
3654*da0073e9SAndroid Build Coastguard Worker                dst = make_arg(dst_size, noncontiguous=not dst_contig)
3655*da0073e9SAndroid Build Coastguard Worker                src = make_arg(src_size, noncontiguous=not src_contig)
3656*da0073e9SAndroid Build Coastguard Worker
3657*da0073e9SAndroid Build Coastguard Worker                # If accumulate=True, `put_` should be deterministic regardless of the inputs on CPU
3658*da0073e9SAndroid Build Coastguard Worker                # On CUDA it may not be, but the test has enough tolerance to account for this
3659*da0073e9SAndroid Build Coastguard Worker                if accumulate:
3660*da0073e9SAndroid Build Coastguard Worker                    idx = make_idx(src_size, high=dst.numel())
3661*da0073e9SAndroid Build Coastguard Worker                else:
3662*da0073e9SAndroid Build Coastguard Worker                    idx = torch.randperm(dst.numel(), dtype=torch.int64, device=device)[:src_size[0]]
3663*da0073e9SAndroid Build Coastguard Worker                if not idx_contig:
3664*da0073e9SAndroid Build Coastguard Worker                    idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2]
3665*da0073e9SAndroid Build Coastguard Worker                if idx_reshape:
3666*da0073e9SAndroid Build Coastguard Worker                    idx = idx.reshape(2, 2)
3667*da0073e9SAndroid Build Coastguard Worker                out = torch.put(dst, idx, src, accumulate)
3668*da0073e9SAndroid Build Coastguard Worker                # out-of-place
3669*da0073e9SAndroid Build Coastguard Worker                reference = ref_put(dst, idx, src, accumulate)
3670*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, reference)
3671*da0073e9SAndroid Build Coastguard Worker
3672*da0073e9SAndroid Build Coastguard Worker                # in-place
3673*da0073e9SAndroid Build Coastguard Worker                dst.put_(idx, src, accumulate)
3674*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dst, reference)
3675*da0073e9SAndroid Build Coastguard Worker
3676*da0073e9SAndroid Build Coastguard Worker
3677*da0073e9SAndroid Build Coastguard Worker        # Create the 8 possible combinations of scalar sizes for target / index / source
3678*da0073e9SAndroid Build Coastguard Worker        scalars = ((make_arg(size_t),
3679*da0073e9SAndroid Build Coastguard Worker                    make_idx(size_i, high=1),
3680*da0073e9SAndroid Build Coastguard Worker                    make_arg(size_s))
3681*da0073e9SAndroid Build Coastguard Worker                   for size_t, size_i, size_s in product([(), (1,)], repeat=3))
3682*da0073e9SAndroid Build Coastguard Worker        for (dest, idx, source), accumulate in product(scalars, [True, False]):
3683*da0073e9SAndroid Build Coastguard Worker            dest_init = dest.clone()
3684*da0073e9SAndroid Build Coastguard Worker            # out-of-place
3685*da0073e9SAndroid Build Coastguard Worker            out = torch.put(dest, idx, source, accumulate=accumulate)
3686*da0073e9SAndroid Build Coastguard Worker            # in-place
3687*da0073e9SAndroid Build Coastguard Worker            dest1 = dest.clone()
3688*da0073e9SAndroid Build Coastguard Worker            dest1.put_(idx, source, accumulate=accumulate)
3689*da0073e9SAndroid Build Coastguard Worker            for d in [out, dest1]:
3690*da0073e9SAndroid Build Coastguard Worker                if accumulate:
3691*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(d.item(), (dest_init + source).item())
3692*da0073e9SAndroid Build Coastguard Worker                else:
3693*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(d.item(), source.item())
3694*da0073e9SAndroid Build Coastguard Worker
3695*da0073e9SAndroid Build Coastguard Worker        # Empty case
3696*da0073e9SAndroid Build Coastguard Worker        dest = make_arg((3, 2))
3697*da0073e9SAndroid Build Coastguard Worker        reference = dest.clone()
3698*da0073e9SAndroid Build Coastguard Worker        idx = make_idx((0,), high=1)
3699*da0073e9SAndroid Build Coastguard Worker        source = make_arg((0,))
3700*da0073e9SAndroid Build Coastguard Worker        for accumulate in [True, False]:
3701*da0073e9SAndroid Build Coastguard Worker            out = torch.put(dest, idx, source, accumulate=accumulate)
3702*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, reference)
3703*da0073e9SAndroid Build Coastguard Worker            dest.put_(idx, source, accumulate=accumulate)
3704*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dest, reference)
3705*da0073e9SAndroid Build Coastguard Worker
3706*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the put operator
3707*da0073e9SAndroid Build Coastguard Worker    # The bool dtype case does not work on GPU. See
3708*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/54317
3709*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3710*da0073e9SAndroid Build Coastguard Worker    def test_put_accumulate(self, device, dtype):
3711*da0073e9SAndroid Build Coastguard Worker        # Test for parallel adds with accumulate == True
3712*da0073e9SAndroid Build Coastguard Worker        low_precision = dtype == torch.half or dtype == torch.bfloat16
3713*da0073e9SAndroid Build Coastguard Worker        # Fewer elements to avoid overflow with low precision
3714*da0073e9SAndroid Build Coastguard Worker        # Grain size is 3000, so the (3002,) case lets the for loop be parallelized on CPU
3715*da0073e9SAndroid Build Coastguard Worker        sizes = ((100,),) if low_precision else ((200,), (3002,))
3716*da0073e9SAndroid Build Coastguard Worker        # Bfloat16 accuracy is particularly poor here
3717*da0073e9SAndroid Build Coastguard Worker        # This operation is nondeterministic on GPU, so we are generous with the rtol
3718*da0073e9SAndroid Build Coastguard Worker        rtol, atol = (1e-1, 1e-2) if low_precision else (1e-3, 1e-4)
3719*da0073e9SAndroid Build Coastguard Worker
3720*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, low=-2, high=3, device=device, dtype=dtype)
3721*da0073e9SAndroid Build Coastguard Worker        # Dump everything into the 0-th position
3722*da0073e9SAndroid Build Coastguard Worker        make_idx = partial(torch.zeros, device=device, dtype=torch.int64)
3723*da0073e9SAndroid Build Coastguard Worker        args = ((make_idx(size), make_arg(size)) for size in sizes)
3724*da0073e9SAndroid Build Coastguard Worker
3725*da0073e9SAndroid Build Coastguard Worker        for idx, source in args:
3726*da0073e9SAndroid Build Coastguard Worker            orig = make_arg((1,))
3727*da0073e9SAndroid Build Coastguard Worker            out = orig.put(idx, source, accumulate=True)
3728*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, orig + source.sum(), rtol=rtol, atol=atol)
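        # Since every index is 0, all of `source` accumulates into the single
        # element of `orig`, so the expected value is orig + source.sum(); the
        # tolerances absorb the varying summation order of parallel adds.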
3729*da0073e9SAndroid Build Coastguard Worker
3730*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the take operator
3731*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
3732*da0073e9SAndroid Build Coastguard Worker    def test_take_empty(self, device):
3733*da0073e9SAndroid Build Coastguard Worker        for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
3734*da0073e9SAndroid Build Coastguard Worker            for indices_shape in [(0,), (0, 1, 2, 0)]:
3735*da0073e9SAndroid Build Coastguard Worker                input = torch.empty(input_shape, device=device)
3736*da0073e9SAndroid Build Coastguard Worker                indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
3737*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(indices, torch.take(input, indices), exact_dtype=False)
3738*da0073e9SAndroid Build Coastguard Worker
3739*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the put operator
3740*da0073e9SAndroid Build Coastguard Worker    def test_put_empty(self, device):
3741*da0073e9SAndroid Build Coastguard Worker        for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
3742*da0073e9SAndroid Build Coastguard Worker            for indices_shape in [(0,), (0, 1, 2, 0)]:
3743*da0073e9SAndroid Build Coastguard Worker                for accumulate in [False, True]:
3744*da0073e9SAndroid Build Coastguard Worker                    dst = torch.randn(dst_shape, device=device)
3745*da0073e9SAndroid Build Coastguard Worker                    indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
3746*da0073e9SAndroid Build Coastguard Worker                    src = torch.randn(indices_shape, device=device)
3747*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate))
3748*da0073e9SAndroid Build Coastguard Worker
3749*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to test_scatter_gather_ops.py
3750*da0073e9SAndroid Build Coastguard Worker    def scatter_allow_reduce(self, device, dtype, reduceop):
3751*da0073e9SAndroid Build Coastguard Worker        device_type = torch.device(device).type
3752*da0073e9SAndroid Build Coastguard Worker        return device_type != 'cuda' or (reduceop == 'multiply' and dtype.is_floating_point)
3753*da0073e9SAndroid Build Coastguard Worker
3754*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3755*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3756*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3757*da0073e9SAndroid Build Coastguard Worker    def test_scatter_reduce_operations_to_large_input(self, device, dtype):
3758*da0073e9SAndroid Build Coastguard Worker        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3759*da0073e9SAndroid Build Coastguard Worker        test_data = [
3760*da0073e9SAndroid Build Coastguard Worker            (torch.zeros(4, 4, device=device, dtype=dtype),
3761*da0073e9SAndroid Build Coastguard Worker             torch.ones(2, 2, device=device, dtype=dtype),
3762*da0073e9SAndroid Build Coastguard Worker             torch.tensor([[0, 0, 0, 0],
3763*da0073e9SAndroid Build Coastguard Worker                           [1, 0, 0, 0],
3764*da0073e9SAndroid Build Coastguard Worker                           [1, 0, 0, 0],
3765*da0073e9SAndroid Build Coastguard Worker                           [0, 0, 0, 0]],
3766*da0073e9SAndroid Build Coastguard Worker                          device=device, dtype=dtype), "add"),
3767*da0073e9SAndroid Build Coastguard Worker            (torch.tensor([2], device=device, dtype=dtype).repeat(4, 4),
3768*da0073e9SAndroid Build Coastguard Worker             torch.tensor([6], device=device, dtype=dtype).repeat(2, 2),
3769*da0073e9SAndroid Build Coastguard Worker             torch.tensor([[2, 2, 2, 2],
3770*da0073e9SAndroid Build Coastguard Worker                           [12, 2, 2, 2],
3771*da0073e9SAndroid Build Coastguard Worker                           [12, 2, 2, 2],
3772*da0073e9SAndroid Build Coastguard Worker                           [2, 2, 2, 2]], device=device, dtype=dtype), "multiply"),
3773*da0073e9SAndroid Build Coastguard Worker        ]
3774*da0073e9SAndroid Build Coastguard Worker
3775*da0073e9SAndroid Build Coastguard Worker        for input, src, result, operation in test_data:
3776*da0073e9SAndroid Build Coastguard Worker            if not self.scatter_allow_reduce(device, dtype, operation):
3777*da0073e9SAndroid Build Coastguard Worker                continue
3778*da0073e9SAndroid Build Coastguard Worker            input.scatter_(0, index, src, reduce=operation)
3779*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input, result)
3780*da0073e9SAndroid Build Coastguard Worker
3781*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3782*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3783*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3784*da0073e9SAndroid Build Coastguard Worker    def test_scatter_reduce_scalar(self, device, dtype):
3785*da0073e9SAndroid Build Coastguard Worker        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3786*da0073e9SAndroid Build Coastguard Worker        test_data = [
3787*da0073e9SAndroid Build Coastguard Worker            (torch.zeros(4, 4, device=device, dtype=dtype), 1,
3788*da0073e9SAndroid Build Coastguard Worker             torch.tensor([[0, 0, 0, 0],
3789*da0073e9SAndroid Build Coastguard Worker                           [1, 0, 0, 0],
3790*da0073e9SAndroid Build Coastguard Worker                           [1, 0, 0, 0],
3791*da0073e9SAndroid Build Coastguard Worker                           [0, 0, 0, 0]],
3792*da0073e9SAndroid Build Coastguard Worker                          device=device, dtype=dtype), "add"),
3793*da0073e9SAndroid Build Coastguard Worker            (torch.tensor([2], device=device, dtype=dtype).repeat(4, 4), 2,
3794*da0073e9SAndroid Build Coastguard Worker             torch.tensor([[2, 2, 2, 2],
3795*da0073e9SAndroid Build Coastguard Worker                           [4, 2, 2, 2],
3796*da0073e9SAndroid Build Coastguard Worker                           [4, 2, 2, 2],
3797*da0073e9SAndroid Build Coastguard Worker                           [2, 2, 2, 2]], device=device, dtype=dtype), "multiply"),
3798*da0073e9SAndroid Build Coastguard Worker        ]
3799*da0073e9SAndroid Build Coastguard Worker
3800*da0073e9SAndroid Build Coastguard Worker        for input, src, result, operation in test_data:
3801*da0073e9SAndroid Build Coastguard Worker            if not self.scatter_allow_reduce(device, dtype, operation):
3802*da0073e9SAndroid Build Coastguard Worker                continue
3803*da0073e9SAndroid Build Coastguard Worker            input.scatter_(0, index, src, reduce=operation)
3804*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input, result)
3805*da0073e9SAndroid Build Coastguard Worker
3806*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to test_scatter_gather_ops.py
3807*da0073e9SAndroid Build Coastguard Worker    # TODO: remove this after scatter_add_ is deprecated.
3808*da0073e9SAndroid Build Coastguard Worker    def test_scatter_add_non_unique_index(self, device):
3809*da0073e9SAndroid Build Coastguard Worker        height = 2
3810*da0073e9SAndroid Build Coastguard Worker        width = 65536
3811*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(height, width, device=device)
3812*da0073e9SAndroid Build Coastguard Worker        index = torch.zeros(height, width, dtype=torch.long, device=device)
3813*da0073e9SAndroid Build Coastguard Worker        src = torch.ones(height, width, device=device)
3814*da0073e9SAndroid Build Coastguard Worker        input.scatter_add_(0, index, src)
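        # Both rows of `src` scatter into row 0 (the index is all zeros), so each
        # element of row 0 becomes 1 + 1 + 1 = 3 while row 1 keeps its initial 1.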
3815*da0073e9SAndroid Build Coastguard Worker
3816*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input,
3817*da0073e9SAndroid Build Coastguard Worker                         torch.tensor([[3], [1]], device=device,
3818*da0073e9SAndroid Build Coastguard Worker                                      dtype=torch.float32).repeat(1, width))
3819*da0073e9SAndroid Build Coastguard Worker
3820*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3821*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3822*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3823*da0073e9SAndroid Build Coastguard Worker    def test_scatter_reduce_non_unique_index(self, device, dtype):
3824*da0073e9SAndroid Build Coastguard Worker        height = 2
3825*da0073e9SAndroid Build Coastguard Worker        width = 2
3826*da0073e9SAndroid Build Coastguard Worker        index = torch.zeros(height, width, dtype=torch.long, device=device)
3827*da0073e9SAndroid Build Coastguard Worker        test_data = [
3828*da0073e9SAndroid Build Coastguard Worker            (torch.ones(height, width, device=device, dtype=dtype),
3829*da0073e9SAndroid Build Coastguard Worker             torch.ones(height, width, device=device, dtype=dtype),
3830*da0073e9SAndroid Build Coastguard Worker             torch.tensor([[3], [1]], device=device, dtype=dtype).repeat(1, width), "add"),
3831*da0073e9SAndroid Build Coastguard Worker            (torch.tensor([2], device=device, dtype=dtype).repeat(height, width),
3832*da0073e9SAndroid Build Coastguard Worker             torch.tensor([2], device=device, dtype=dtype).repeat(height, width),
3833*da0073e9SAndroid Build Coastguard Worker             torch.tensor([[8], [2]], device=device,
3834*da0073e9SAndroid Build Coastguard Worker                          dtype=dtype).repeat(1, width), "multiply"),
3835*da0073e9SAndroid Build Coastguard Worker        ]
3836*da0073e9SAndroid Build Coastguard Worker
3837*da0073e9SAndroid Build Coastguard Worker        for input, src, result, operation in test_data:
3838*da0073e9SAndroid Build Coastguard Worker            if not self.scatter_allow_reduce(device, dtype, operation):
3839*da0073e9SAndroid Build Coastguard Worker                continue
3840*da0073e9SAndroid Build Coastguard Worker            input.scatter_(0, index, src, reduce=operation)
3841*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}")
3842*da0073e9SAndroid Build Coastguard Worker
3843*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
3844*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
3845*da0073e9SAndroid Build Coastguard Worker    def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype):
3846*da0073e9SAndroid Build Coastguard Worker        height = 2
3847*da0073e9SAndroid Build Coastguard Worker        width = 2
3848*da0073e9SAndroid Build Coastguard Worker        index = torch.zeros(height, width, dtype=torch.long, device=device)
3849*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(height, width, device=device, dtype=dtype)
3850*da0073e9SAndroid Build Coastguard Worker        src = torch.ones(height, width, device=device, dtype=dtype)
3851*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3852*da0073e9SAndroid Build Coastguard Worker            input.scatter_(0, index, src, reduce="multiply")
3853*da0073e9SAndroid Build Coastguard Worker
3854*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to test_scatter_gather_ops.py
3855*da0073e9SAndroid Build Coastguard Worker    def test_scatter_to_large_input(self, device):
3856*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros(4, 4, device=device)
3857*da0073e9SAndroid Build Coastguard Worker        src = torch.ones(2, 2, device=device)
3858*da0073e9SAndroid Build Coastguard Worker        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3859*da0073e9SAndroid Build Coastguard Worker        input.scatter_(0, index, src)
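        # index has shape (2, 1), so only column 0 of the 4x4 input is written:
        # src[0][0] lands in row 1 and src[1][0] lands in row 2.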
3860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input, torch.tensor([[0, 0, 0, 0],
3861*da0073e9SAndroid Build Coastguard Worker                                              [1, 0, 0, 0],
3862*da0073e9SAndroid Build Coastguard Worker                                              [1, 0, 0, 0],
3863*da0073e9SAndroid Build Coastguard Worker                                              [0, 0, 0, 0]], device=device, dtype=torch.float32))
3864*da0073e9SAndroid Build Coastguard Worker
3865*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to test_scatter_gather_ops.py
3866*da0073e9SAndroid Build Coastguard Worker    def test_scatter_add_to_large_input(self, device):
3867*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros(4, 4, device=device)
3868*da0073e9SAndroid Build Coastguard Worker        src = torch.ones(2, 2, device=device)
3869*da0073e9SAndroid Build Coastguard Worker        index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3870*da0073e9SAndroid Build Coastguard Worker        input.scatter_add_(0, index, src)
3871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input, torch.tensor([[0, 0, 0, 0],
3872*da0073e9SAndroid Build Coastguard Worker                                              [1, 0, 0, 0],
3873*da0073e9SAndroid Build Coastguard Worker                                              [1, 0, 0, 0],
3874*da0073e9SAndroid Build Coastguard Worker                                              [0, 0, 0, 0]], device=device, dtype=torch.float32))
3875*da0073e9SAndroid Build Coastguard Worker
3876*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to test_scatter_gather_ops.py
3877*da0073e9SAndroid Build Coastguard Worker    def test_scatter_bool(self, device):
3878*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[True, True, True], [True, True, True]], device=device)
3879*da0073e9SAndroid Build Coastguard Worker        res = torch.zeros(3, 3, dtype=torch.bool, device=device)
3880*da0073e9SAndroid Build Coastguard Worker        res = res.scatter_(0, torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), x)
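        # Both index rows map column j to res[j][j], so only the diagonal of
        # res ends up True.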
3881*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, torch.tensor([[True, False, False],
3882*da0073e9SAndroid Build Coastguard Worker                                            [False, True, False],
3883*da0073e9SAndroid Build Coastguard Worker                                            [False, False, True]], device=device))
3884*da0073e9SAndroid Build Coastguard Worker
3885*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to test_scatter_gather_ops.py
3886*da0073e9SAndroid Build Coastguard Worker    def test_scatter_add_bool(self, device):
3887*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]], device=device)
3888*da0073e9SAndroid Build Coastguard Worker        res = torch.zeros(3, 5, dtype=torch.bool, device=device)
3889*da0073e9SAndroid Build Coastguard Worker        res = res.scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], device=device), x)
3890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, torch.tensor([[True, True, True, True, True],
3891*da0073e9SAndroid Build Coastguard Worker                                            [False, True, False, True, False],
3892*da0073e9SAndroid Build Coastguard Worker                                            [True, False, True, False, True]], device=device))
3893*da0073e9SAndroid Build Coastguard Worker
3894*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the masked scatter operator
3895*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3896*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
3897*da0073e9SAndroid Build Coastguard Worker    def test_masked_scatter(self, device, dtype):
3898*da0073e9SAndroid Build Coastguard Worker        dt = dtype
3899*da0073e9SAndroid Build Coastguard Worker        num_copy, num_dest = 3, 10
3900*da0073e9SAndroid Build Coastguard Worker        dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt, device=device)
3901*da0073e9SAndroid Build Coastguard Worker        dest2 = dest.clone()
3902*da0073e9SAndroid Build Coastguard Worker        dest_ones = dest.clone()
3903*da0073e9SAndroid Build Coastguard Worker        dest_ones_expected = dest.clone()
3904*da0073e9SAndroid Build Coastguard Worker        src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device)
3905*da0073e9SAndroid Build Coastguard Worker        src_ones = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device)
3906*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=torch.bool, device=device)
3907*da0073e9SAndroid Build Coastguard Worker
3908*da0073e9SAndroid Build Coastguard Worker        dest.masked_scatter_(mask, src)
3909*da0073e9SAndroid Build Coastguard Worker        j = 0
3910*da0073e9SAndroid Build Coastguard Worker        for i in range(num_dest):
3911*da0073e9SAndroid Build Coastguard Worker            if mask[i]:
3912*da0073e9SAndroid Build Coastguard Worker                dest2[i] = src[j]
3913*da0073e9SAndroid Build Coastguard Worker                dest_ones_expected[i] = src_ones[j]
3914*da0073e9SAndroid Build Coastguard Worker                j += 1
3915*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dest, dest2, atol=0, rtol=0)
3916*da0073e9SAndroid Build Coastguard Worker
3917*da0073e9SAndroid Build Coastguard Worker        dest_ones.masked_scatter_(mask, src_ones)
3918*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0)
3919*da0073e9SAndroid Build Coastguard Worker
3920*da0073e9SAndroid Build Coastguard Worker        # Bounds checking on CUDA is done inside the kernel
3921*da0073e9SAndroid Build Coastguard Worker        # to avoid synchronization, which means the failure state
3922*da0073e9SAndroid Build Coastguard Worker        # cannot be cleared. So there is no way to trigger the
3923*da0073e9SAndroid Build Coastguard Worker        # error and then recover from it.
3924*da0073e9SAndroid Build Coastguard Worker        if self.device_type != 'cuda':
3925*da0073e9SAndroid Build Coastguard Worker            # Make src smaller; this should fail.
3926*da0073e9SAndroid Build Coastguard Worker            src = torch.zeros(num_copy - 1, dtype=dt, device=device)
3927*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
3928*da0073e9SAndroid Build Coastguard Worker                dest.masked_scatter_(mask, src)
3929*da0073e9SAndroid Build Coastguard Worker
3930*da0073e9SAndroid Build Coastguard Worker        # empty tensor
3931*da0073e9SAndroid Build Coastguard Worker        dest = torch.empty((5, 0, 5), dtype=dt, device=device)
3932*da0073e9SAndroid Build Coastguard Worker        mask = torch.ones_like(dest, dtype=torch.bool, device=device)
3933*da0073e9SAndroid Build Coastguard Worker        src = torch.empty((0,), dtype=dt, device=device)
3934*da0073e9SAndroid Build Coastguard Worker        dest.masked_scatter_(mask, src)
3935*da0073e9SAndroid Build Coastguard Worker
3936*da0073e9SAndroid Build Coastguard Worker        dest = torch.empty((5, 0, 5), dtype=dt, device=device)
3937*da0073e9SAndroid Build Coastguard Worker        mask = torch.ones((5, 1, 5), dtype=torch.bool, device=device)
3938*da0073e9SAndroid Build Coastguard Worker        src = torch.empty((0,), dtype=dt, device=device)
3939*da0073e9SAndroid Build Coastguard Worker        dest.masked_scatter_(mask, src)
3940*da0073e9SAndroid Build Coastguard Worker
3941*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the masked scatter operator
3942*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
3943*da0073e9SAndroid Build Coastguard Worker    def test_masked_scatter_bool_tensor(self, device):
3944*da0073e9SAndroid Build Coastguard Worker        src = torch.tensor([True, True, True], device=device)
3945*da0073e9SAndroid Build Coastguard Worker        dst = torch.tensor([False, False, False], device=device)
3946*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor([False, True, False], device=device)
3947*da0073e9SAndroid Build Coastguard Worker
3948*da0073e9SAndroid Build Coastguard Worker        dst.masked_scatter_(mask, src)
3949*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, torch.tensor([False, True, False], device=device))
3950*da0073e9SAndroid Build Coastguard Worker
3951*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor([True, False, True], device=device)
3952*da0073e9SAndroid Build Coastguard Worker        dst = dst.masked_scatter(mask, src)
3953*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, torch.tensor([True, True, True], device=device))
3954*da0073e9SAndroid Build Coastguard Worker
3955*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the masked scatter operator
3956*da0073e9SAndroid Build Coastguard Worker    #   test_scatter_gather_ops or test_masked_ops?
3957*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
3958*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('30GB')
3959*da0073e9SAndroid Build Coastguard Worker    def test_masked_scatter_large_tensor(self, device):
3960*da0073e9SAndroid Build Coastguard Worker        t_cpu = torch.empty(2**31 + 1, dtype=torch.bool).random_()
3961*da0073e9SAndroid Build Coastguard Worker        t = t_cpu.to(device)
3962*da0073e9SAndroid Build Coastguard Worker        result_cpu = t_cpu.masked_scatter(t_cpu, t_cpu)
3963*da0073e9SAndroid Build Coastguard Worker        result = t.masked_scatter(t, t)
3964*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, result_cpu)
3965*da0073e9SAndroid Build Coastguard Worker
3966*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the masked select operator
3967*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3968*da0073e9SAndroid Build Coastguard Worker    def test_masked_select(self, device, dtype):
3969*da0073e9SAndroid Build Coastguard Worker        if device == 'cpu':
3970*da0073e9SAndroid Build Coastguard Worker            warn = 'masked_select received a mask with dtype torch.uint8,'
3971*da0073e9SAndroid Build Coastguard Worker        else:
3972*da0073e9SAndroid Build Coastguard Worker            warn = 'indexing with dtype torch.uint8 is now deprecated, pl'
3973*da0073e9SAndroid Build Coastguard Worker        for maskType in integral_types_and(torch.bool):
3974*da0073e9SAndroid Build Coastguard Worker            num_src = 10
3975*da0073e9SAndroid Build Coastguard Worker            src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device)
3976*da0073e9SAndroid Build Coastguard Worker            mask = torch.randint(2, (num_src,), device=device, dtype=maskType)
3977*da0073e9SAndroid Build Coastguard Worker
3978*da0073e9SAndroid Build Coastguard Worker            if maskType is not torch.bool:
3979*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, r'expected BoolTensor for mask'):
3980*da0073e9SAndroid Build Coastguard Worker                    dst = src.masked_select(mask)
3981*da0073e9SAndroid Build Coastguard Worker                continue
3982*da0073e9SAndroid Build Coastguard Worker            else:
3983*da0073e9SAndroid Build Coastguard Worker                dst = src.masked_select(mask)
3984*da0073e9SAndroid Build Coastguard Worker            dst2 = []
3985*da0073e9SAndroid Build Coastguard Worker            for i in range(num_src):
3986*da0073e9SAndroid Build Coastguard Worker                if mask[i]:
3987*da0073e9SAndroid Build Coastguard Worker                    dst2 += [src[i]]
3988*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst, torch.tensor(dst2), atol=0, rtol=0)
3989*da0073e9SAndroid Build Coastguard Worker
3990*da0073e9SAndroid Build Coastguard Worker            dst3 = torch.empty(0, device=device, dtype=dtype)
3991*da0073e9SAndroid Build Coastguard Worker            torch.masked_select(src, mask, out=dst3)
3992*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dst3, torch.tensor(dst2, dtype=dst3.dtype), atol=0, rtol=0)
3993*da0073e9SAndroid Build Coastguard Worker
3994*da0073e9SAndroid Build Coastguard Worker        # half is not supported on CPU, so skip the remaining test cases
3995*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.half and torch.device(device).type == 'cpu':
3996*da0073e9SAndroid Build Coastguard Worker            return
3997*da0073e9SAndroid Build Coastguard Worker
3998*da0073e9SAndroid Build Coastguard Worker        # Ensure that masks are expanded to match tensor properly
3999*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(100, 100, device=device).mul(100).to(dtype)
4000*da0073e9SAndroid Build Coastguard Worker        mask_first_el_each_row = torch.zeros(100, device=device, dtype=torch.bool)
4001*da0073e9SAndroid Build Coastguard Worker        mask_first_el_each_row[0] = True
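        # The 1-D mask of shape (100,) broadcasts as a row over the (100, 100)
        # tensor, so only column 0 is selected from every row, i.e. a[:, 0].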
4002*da0073e9SAndroid Build Coastguard Worker        a_masked = a.masked_select(mask_first_el_each_row)
4003*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_masked, a[:, 0])
4004*da0073e9SAndroid Build Coastguard Worker
4005*da0073e9SAndroid Build Coastguard Worker        mask_first_row = torch.zeros(100, 1, device=device, dtype=torch.bool)
4006*da0073e9SAndroid Build Coastguard Worker        mask_first_row[0][0] = True
4007*da0073e9SAndroid Build Coastguard Worker        a_masked = a.masked_select(mask_first_row)
4008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_masked, a[0, :])
4009*da0073e9SAndroid Build Coastguard Worker
4010*da0073e9SAndroid Build Coastguard Worker        # Ensure that tensor is expanded to match mask properly
4011*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(100, device=device).mul(100).to(dtype)
4012*da0073e9SAndroid Build Coastguard Worker        mask_copy_3_times = torch.tensor([[True], [True], [False], [True]], device=device)
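        # The (4, 1) mask and the (100,) tensor broadcast to (4, 100); three of
        # the four mask rows are True, so `a` appears three times in the output.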
4013*da0073e9SAndroid Build Coastguard Worker        a_masked = a.masked_select(mask_copy_3_times)
4014*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_masked, a.unsqueeze(0).expand(3, 100).flatten())
4015*da0073e9SAndroid Build Coastguard Worker
4016*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the masked select operator
4017*da0073e9SAndroid Build Coastguard Worker    def test_masked_select_discontiguous(self, device):
4018*da0073e9SAndroid Build Coastguard Worker        for size in (10, 200):
4019*da0073e9SAndroid Build Coastguard Worker            vals = torch.rand(size, size, device=device)
4020*da0073e9SAndroid Build Coastguard Worker            mask = torch.full((size, size), False, dtype=torch.bool, device=device)
4021*da0073e9SAndroid Build Coastguard Worker            mask[:, ::2] = True
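            # Every other column of the mask is True; pairing vals/mask with
            # their transposes below covers contiguous and non-contiguous
            # inputs, and out_dc is a strided (non-contiguous) output buffer.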
4022*da0073e9SAndroid Build Coastguard Worker            vals_list = (vals, vals.t())
4023*da0073e9SAndroid Build Coastguard Worker            mask_list = (mask, mask.t())
4024*da0073e9SAndroid Build Coastguard Worker            out_dc = torch.empty(size * size, device=device)[::2]
4025*da0073e9SAndroid Build Coastguard Worker            for v, m in product(vals_list, mask_list):
4026*da0073e9SAndroid Build Coastguard Worker                if m.is_contiguous():
4027*da0073e9SAndroid Build Coastguard Worker                    expected = v[:, ::2].clone().reshape((-1, ))
4028*da0073e9SAndroid Build Coastguard Worker                else:
4029*da0073e9SAndroid Build Coastguard Worker                    expected = v[::2].clone().reshape((-1, ))
4030*da0073e9SAndroid Build Coastguard Worker                out = torch.masked_select(v, m)
4031*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, expected, atol=0, rtol=0)
4032*da0073e9SAndroid Build Coastguard Worker                torch.masked_select(v, m, out=out_dc)
4033*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_dc, expected, atol=0, rtol=0)
4034*da0073e9SAndroid Build Coastguard Worker
4035*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the masked fill operator
4036*da0073e9SAndroid Build Coastguard Worker    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16), (torch.uint8, torch.bool)))
4037*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill(self, device, dtypes):
4038*da0073e9SAndroid Build Coastguard Worker        dtype = dtypes[0]
4039*da0073e9SAndroid Build Coastguard Worker        mask_dtype = dtypes[1]
4040*da0073e9SAndroid Build Coastguard Worker
4041*da0073e9SAndroid Build Coastguard Worker        num_dest = 10
4042*da0073e9SAndroid Build Coastguard Worker        dst = torch.zeros(num_dest, dtype=dtype)
4043*da0073e9SAndroid Build Coastguard Worker        mask = torch.randint(2, (num_dest,), dtype=mask_dtype)
4044*da0073e9SAndroid Build Coastguard Worker        val = random.random()
4045*da0073e9SAndroid Build Coastguard Worker        dst2 = dst.clone()
4046*da0073e9SAndroid Build Coastguard Worker
4047*da0073e9SAndroid Build Coastguard Worker        if mask_dtype is not torch.bool:
4048*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'only supports boolean masks'):
4049*da0073e9SAndroid Build Coastguard Worker                dst.masked_fill_(mask, val)
4050*da0073e9SAndroid Build Coastguard Worker            return
4051*da0073e9SAndroid Build Coastguard Worker
4052*da0073e9SAndroid Build Coastguard Worker        dst.masked_fill_(mask, val)
4053*da0073e9SAndroid Build Coastguard Worker        for i in range(num_dest):
4054*da0073e9SAndroid Build Coastguard Worker            if mask[i]:
4055*da0073e9SAndroid Build Coastguard Worker                dst2[i] = val
4056*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, dst2, atol=0, rtol=0)
4057*da0073e9SAndroid Build Coastguard Worker
4058*da0073e9SAndroid Build Coastguard Worker        # test non-contiguous case
4059*da0073e9SAndroid Build Coastguard Worker        dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
4060*da0073e9SAndroid Build Coastguard Worker        dst2 = dst.contiguous()
4061*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
4062*da0073e9SAndroid Build Coastguard Worker            mask = dst.abs() > 0
4063*da0073e9SAndroid Build Coastguard Worker        else:
4064*da0073e9SAndroid Build Coastguard Worker            mask = dst > 0
4065*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not dst.is_contiguous())
4066*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(dst2.is_contiguous())
4067*da0073e9SAndroid Build Coastguard Worker        dst.masked_fill_(mask.to(mask_dtype), val)
4068*da0073e9SAndroid Build Coastguard Worker        dst2.masked_fill_(mask.to(mask_dtype), val)
4069*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, dst2, atol=0, rtol=0)
4070*da0073e9SAndroid Build Coastguard Worker
4071*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the masked fill operator
4072*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill_bool_tensor(self, device):
4073*da0073e9SAndroid Build Coastguard Worker        dst = torch.tensor([True, False, True], device=device)
4074*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor([False, True, False], device=device)
4075*da0073e9SAndroid Build Coastguard Worker
4076*da0073e9SAndroid Build Coastguard Worker        dst.masked_fill_(mask, True)
4077*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, torch.tensor([True, True, True], device=device))
4078*da0073e9SAndroid Build Coastguard Worker
4079*da0073e9SAndroid Build Coastguard Worker        dst = dst.masked_fill(mask, False)
4080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst, torch.tensor([True, False, True], device=device))
4081*da0073e9SAndroid Build Coastguard Worker
4082*da0073e9SAndroid Build Coastguard Worker    def test_tensor_shape_empty(self, device):
4083*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((0, 1, 3, 0), device=device)
4084*da0073e9SAndroid Build Coastguard Worker        # flatten
4085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), torch.flatten(x, 0, 3).shape)
4086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), torch.flatten(x, 0, 2).shape)
4087*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 3, 0), torch.flatten(x, 1, 2).shape)
4088*da0073e9SAndroid Build Coastguard Worker
4089*da0073e9SAndroid Build Coastguard Worker        # squeeze, unsqueeze
4090*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 1, 1, 3, 0), torch.unsqueeze(x, 1).shape)
4091*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 3, 0), torch.squeeze(x, 1).shape)
4092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 3, 0), torch.squeeze(x).shape)
4093*da0073e9SAndroid Build Coastguard Worker
4094*da0073e9SAndroid Build Coastguard Worker        # transpose, t
4095*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0, 3, 1), torch.transpose(x, 1, 3).shape)
4096*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((5, 0), device=device)
4097*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5), y.t().shape)
4098*da0073e9SAndroid Build Coastguard Worker
4099*da0073e9SAndroid Build Coastguard Worker        # select
4100*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape)
4101*da0073e9SAndroid Build Coastguard Worker
4102*da0073e9SAndroid Build Coastguard Worker        # repeat, permute
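        # repeat with five reps prepends a dim of 9 and multiplies the existing
        # sizes element-wise: (0*7, 1*5, 3*2, 0*3) -> (9, 0, 5, 6, 0)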
4103*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 5, 2, 3).shape)
4104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((3, 0, 0, 1), x.permute(2, 3, 0, 1).shape)
4105*da0073e9SAndroid Build Coastguard Worker
4106*da0073e9SAndroid Build Coastguard Worker        # diagonal, diagflat
4107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device)).shape)
4108*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device)).shape)
4109*da0073e9SAndroid Build Coastguard Worker        # offsets that run off the end are still valid
4110*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device), offset=1).shape)
4111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device), offset=1).shape)
4112*da0073e9SAndroid Build Coastguard Worker        # check offsets off the end for a tensor with non-zero sizes
4113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=45252).shape)
4114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=-45252).shape)
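        # diagonal() runs over dims 0 and 1 by default, so the remaining dims
        # (5, 6) are kept and the out-of-range offset yields a zero-length
        # diagonal, i.e. shape (5, 6, 0).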
4115*da0073e9SAndroid Build Coastguard Worker
4116*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), torch.diagflat(torch.tensor([], device=device)).shape)
4117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([], device=device), offset=1))
4118*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), torch.diagflat(torch.tensor([[]], device=device)).shape)
4119*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([[]], device=device), offset=1))
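        # diagflat builds a square matrix of side numel + |offset|, so an empty
        # input with offset=1 still yields a 1x1 zero matrix.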
4120*da0073e9SAndroid Build Coastguard Worker
4121*da0073e9SAndroid Build Coastguard Worker        # stack, split, chunk
4122*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((4, 0, 1, 3, 0), torch.stack((x, x, x, x)).shape)
4123*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(0, 1, 3, 0)],
4124*da0073e9SAndroid Build Coastguard Worker                         [z.shape for z in torch.chunk(x, 1, dim=0)])
4125*da0073e9SAndroid Build Coastguard Worker
4126*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(0, 1, 3, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=0)])
4127*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(0, 1, 1, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=2)])
4128*da0073e9SAndroid Build Coastguard Worker
4129*da0073e9SAndroid Build Coastguard Worker        # NOTE: split_with_sizes behaves differently from NumPy's split in that
4130*da0073e9SAndroid Build Coastguard Worker        # it takes the sizes of the pieces rather than the indices at which to split
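        # e.g. torch.split(x, (0, 1, 2), dim=2) carves the size-3 dim into
        # pieces of width 0, 1 and 2, whereas numpy.split given a list splits
        # at those indices.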
4131*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)],
4132*da0073e9SAndroid Build Coastguard Worker                         [z.shape for z in torch.split(x, (0, 1, 2), dim=2)])
4133*da0073e9SAndroid Build Coastguard Worker
4134*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1))
4135*da0073e9SAndroid Build Coastguard Worker        # This is strange because the split size is larger than the dim size, but consistent with
4136*da0073e9SAndroid Build Coastguard Worker        # how split handles that case generally (when no 0s are involved).
4137*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)])
4138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
4139*da0073e9SAndroid Build Coastguard Worker
4140*da0073e9SAndroid Build Coastguard Worker    # functions that operate over a dimension but don't reduce.
4141*da0073e9SAndroid Build Coastguard Worker    def test_dim_function_empty(self, device):
4142*da0073e9SAndroid Build Coastguard Worker        shape = (0, 1, 2, 0)
4143*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(shape, device=device)
4144*da0073e9SAndroid Build Coastguard Worker
4145*da0073e9SAndroid Build Coastguard Worker        # size, stride
4146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, x.size(3))
4147*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, x.size(2))
4148*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, x.stride(0))
4149*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, x.stride(2))
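        # Contiguous strides treat zero-sized dims as if they had size 1, so
        # shape (0, 1, 2, 0) gets strides (2, 2, 1, 1), matching the two
        # checks above.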
4150*da0073e9SAndroid Build Coastguard Worker
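        # glu halves the chosen dim: dim 0 stays empty (0 // 2 == 0), dim 2
        # goes from 2 to 1.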
4151*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.nn.functional.glu(x, 0))
4152*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 1, 1, 0), torch.nn.functional.glu(x, 2).shape)
4153*da0073e9SAndroid Build Coastguard Worker
4154*da0073e9SAndroid Build Coastguard Worker        # softmax, logsoftmax
4155*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.nn.functional.softmax(x, 0))
4156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.nn.functional.softmax(x, 2))
4157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.nn.functional.softmax(x, 3))
4158*da0073e9SAndroid Build Coastguard Worker
4159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.nn.functional.log_softmax(x, 0))
4160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.nn.functional.log_softmax(x, 2))
4161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.nn.functional.log_softmax(x, 3))
4162*da0073e9SAndroid Build Coastguard Worker
4163*da0073e9SAndroid Build Coastguard Worker        # cumsum, cumprod, cummax, cummin
4164*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.cumsum(x, 0).shape)
4165*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.cumsum(x, 2).shape)
4166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.cumprod(x, 0).shape)
4167*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.cumprod(x, 2).shape)
4168*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.cummax(x, 0)[0].shape)
4169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.cummax(x, 2)[0].shape)
4170*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.cummin(x, 0)[0].shape)
4171*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.cummin(x, 2)[0].shape)
4172*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.logcumsumexp(x, 0).shape)
4173*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.logcumsumexp(x, 2).shape)
4174*da0073e9SAndroid Build Coastguard Worker
4175*da0073e9SAndroid Build Coastguard Worker        # flip
4176*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.flip(0))
4177*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.flip(2))
4178*da0073e9SAndroid Build Coastguard Worker
4179*da0073e9SAndroid Build Coastguard Worker        # roll
4180*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.roll(0, 1).roll(0, -1))
4181*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.roll(1, x.size(1)))
4182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.roll(1))
4183*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.roll((1, 1), (3, 1)))
4184*da0073e9SAndroid Build Coastguard Worker
4185*da0073e9SAndroid Build Coastguard Worker        # unbind
4186*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((), x.unbind(0))
4187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((torch.empty((0, 1, 0), device=device), torch.empty((0, 1, 0), device=device)),
4188*da0073e9SAndroid Build Coastguard Worker                         x.unbind(2))
4189*da0073e9SAndroid Build Coastguard Worker
4190*da0073e9SAndroid Build Coastguard Worker        # cross
4191*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((0, 1, 3, 0), device=device)
4192*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.shape, torch.cross(y, y).shape)
4193*da0073e9SAndroid Build Coastguard Worker
4194*da0073e9SAndroid Build Coastguard Worker        # renorm
4195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.renorm(x, 1, 0, 5).shape)
4196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.renorm(x, 1, 2, 5).shape)
4197*da0073e9SAndroid Build Coastguard Worker
4198*da0073e9SAndroid Build Coastguard Worker        # sort
4199*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=0)])
4200*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=2)])
4201*da0073e9SAndroid Build Coastguard Worker
4202*da0073e9SAndroid Build Coastguard Worker        # topk
4203*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([shape, shape], [z.shape for z in torch.topk(x, 0, dim=0)])
4204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(0, 1, 1, 0), (0, 1, 1, 0)], [z.shape for z in torch.topk(x, 1, dim=2)])
4205*da0073e9SAndroid Build Coastguard Worker
4206*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((2, 3, 4), device=device)
4207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(2, 3, 0), (2, 3, 0)], [z.shape for z in torch.topk(y, 0)])
4208*da0073e9SAndroid Build Coastguard Worker
4209*da0073e9SAndroid Build Coastguard Worker        # gather
4210*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64, device=device)).shape)
4211*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64, device=device)).shape)
4212*da0073e9SAndroid Build Coastguard Worker        larger_shape = torch.empty((0, 1, 3, 0), dtype=torch.int64, device=device)
4213*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(larger_shape.shape, torch.gather(x, 2, larger_shape).shape)
4214*da0073e9SAndroid Build Coastguard Worker        smaller_shape = torch.empty((0, 1, 0, 0), dtype=torch.int64, device=device)
4215*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(smaller_shape.shape, torch.gather(x, 2, smaller_shape).shape)
4216*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((2, 3, 4), device=device)
4217*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 3, 4),
4218*da0073e9SAndroid Build Coastguard Worker                         torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64, device=device)).shape)
4219*da0073e9SAndroid Build Coastguard Worker
4220*da0073e9SAndroid Build Coastguard Worker        # scatter, scatter_add
4221*da0073e9SAndroid Build Coastguard Worker        for dim in [0, 2]:
4222*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(shape, device=device)
4223*da0073e9SAndroid Build Coastguard Worker            y_src = torch.randn(shape, device=device)
4224*da0073e9SAndroid Build Coastguard Worker            ind = torch.empty(shape, dtype=torch.int64, device=device)
4225*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(shape, y.scatter_(dim, ind, y_src).shape)
4226*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(shape, y.scatter_add_(dim, ind, y_src).shape)
4227*da0073e9SAndroid Build Coastguard Worker
4228*da0073e9SAndroid Build Coastguard Worker        z = torch.randn((2, 3, 4), device=device)
4229*da0073e9SAndroid Build Coastguard Worker        z_src = torch.randn((2, 3, 4), device=device)
4230*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src))
4231*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src))
4232*da0073e9SAndroid Build Coastguard Worker
4233*da0073e9SAndroid Build Coastguard Worker        # index_fill, index_copy, index_add
4234*da0073e9SAndroid Build Coastguard Worker        c = x.clone()
4235*da0073e9SAndroid Build Coastguard Worker        c_clone = c.clone()
4236*da0073e9SAndroid Build Coastguard Worker        ind_empty = torch.tensor([], dtype=torch.int64, device=device)
4237*da0073e9SAndroid Build Coastguard Worker        ind_01 = torch.tensor([0, 1], dtype=torch.int64, device=device)
4238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
4239*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_fill_(2, ind_empty, -1))
4240*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_fill_(2, ind_01, -1))
4241*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device)))
4242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device)))
4243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device)))
4244*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device)))
4245*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device)))
4246*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device)))
4247*da0073e9SAndroid Build Coastguard Worker
4248*da0073e9SAndroid Build Coastguard Worker        c = torch.randn((0, 1, 2), device=device)
4249*da0073e9SAndroid Build Coastguard Worker        c_clone = c.clone()
4250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
4251*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
4252*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
4253*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
4254*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
4255*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
4256*da0073e9SAndroid Build Coastguard Worker
4257*da0073e9SAndroid Build Coastguard Worker        # index fill/copy/add non-empty
4258*da0073e9SAndroid Build Coastguard Worker        z = torch.randn((2, 3, 4), device=device)
4259*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, z.index_fill_(0, ind_empty, -1))
4260*da0073e9SAndroid Build Coastguard Worker        z = torch.randn((2, 3, 4), device=device)
4261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, z.index_copy_(0, ind_empty, torch.empty((0, 3, 4), device=device)))
4262*da0073e9SAndroid Build Coastguard Worker        z = torch.randn((2, 3, 4), device=device)
4263*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, z.index_add_(0, ind_empty, torch.empty((0, 3, 4), device=device)))
4264*da0073e9SAndroid Build Coastguard Worker
4265*da0073e9SAndroid Build Coastguard Worker        # index_select
4266*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.index_select(0, ind_empty))
4267*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 1, 0, 0), x.index_select(2, ind_empty).shape)
4268*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.index_select(2, ind_01))
4269*da0073e9SAndroid Build Coastguard Worker        z = torch.randn((2, 3, 4), device=device)  # non-empty
4270*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 3, 4), z.index_select(0, ind_empty).shape)
4271*da0073e9SAndroid Build Coastguard Worker        c = torch.randn((0, 1, 2), device=device)
4272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, c.index_select(0, ind_empty))
4273*da0073e9SAndroid Build Coastguard Worker        c = torch.randn((0, 1, 2), device=device)
4274*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, c.index_select(0, ind_empty))
4275*da0073e9SAndroid Build Coastguard Worker        w = torch.randn((0, 3), device=device)
4276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 2), w.index_select(1, ind_01).shape)
4277*da0073e9SAndroid Build Coastguard Worker        w = torch.randn((3, 0), device=device)
4278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((2, 0), w.index_select(0, ind_01).shape)
4279*da0073e9SAndroid Build Coastguard Worker        ind_01_int32 = torch.tensor([0, 1], dtype=torch.int32, device=device)
4280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((2, 0), w.index_select(0, ind_01_int32).shape)
4281*da0073e9SAndroid Build Coastguard Worker        s = torch.randn([], device=device)
4282*da0073e9SAndroid Build Coastguard Worker        ind_0 = torch.tensor([0], dtype=torch.int32, device=device)
4283*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([], s.index_select(0, ind_0).shape)
4284*da0073e9SAndroid Build Coastguard Worker        if device == 'cpu':
4285*da0073e9SAndroid Build Coastguard Worker            w = torch.randn((0, 3), device=device)
4286*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "self indexing axis dim should be positive"):
4287*da0073e9SAndroid Build Coastguard Worker                torch.index_select(w, 0, ind_01)
4288*da0073e9SAndroid Build Coastguard Worker            ind_05 = torch.tensor([0, 5], dtype=torch.int64, device=device)
4289*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "INDICES element is out of DATA bounds"):
4290*da0073e9SAndroid Build Coastguard Worker                torch.index_select(w, 1, ind_05)
4291*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
4292*da0073e9SAndroid Build Coastguard Worker                torch.index_select(s, 0, ind_empty)
4293*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
4294*da0073e9SAndroid Build Coastguard Worker            torch.ones([]).index_select(0, torch.Tensor([0, 0]).int())
4295*da0073e9SAndroid Build Coastguard Worker
4296*da0073e9SAndroid Build Coastguard Worker    # FIXME: find a test suite for the pdist operator
4297*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration")
4298*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
4299*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4300*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('32GB', device='cpu')
4301*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('5GB', device='cuda')
4302*da0073e9SAndroid Build Coastguard Worker    def test_pdist_norm_large(self, device):
4303*da0073e9SAndroid Build Coastguard Worker        # use dim0>=46342 for forward, see:
4304*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/30583
4305*da0073e9SAndroid Build Coastguard Worker        # Compare output using GPU with the CPU implementation
4306*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(50000, 1, dtype=torch.float32)      # 50k * 4 bytes = 200 KB
4307*da0073e9SAndroid Build Coastguard Worker        # pdist produces n * (n - 1) / 2 = 1,249,975,000 float32 distances
4308*da0073e9SAndroid Build Coastguard Worker        expected_cpu = torch.pdist(x, p=2)                  # ~1250M * 4 bytes = 5 GB on CPU
4309*da0073e9SAndroid Build Coastguard Worker        actual_cpu = torch.pdist(x.to(device), p=2).cpu()         # 5 GB on GPU + 5 GB on CPU
4310*da0073e9SAndroid Build Coastguard Worker        # Workaround for large memory overhead of self.assertTrue (see #84944)
4311*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(expected_cpu, actual_cpu))  # ~20GB in allclose
4312*da0073e9SAndroid Build Coastguard Worker
4313*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to elementwise ternary test suite
4314*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4315*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
4316*da0073e9SAndroid Build Coastguard Worker    @dtypes(*set(get_all_math_dtypes('cpu')))
4317*da0073e9SAndroid Build Coastguard Worker    def test_addcdiv(self, device, dtype):
4318*da0073e9SAndroid Build Coastguard Worker        # Returns floating or integral scalar corresponding to dtype
4319*da0073e9SAndroid Build Coastguard Worker        def _number(floating, integer, dtype):
4320*da0073e9SAndroid Build Coastguard Worker            if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
4321*da0073e9SAndroid Build Coastguard Worker                return floating
4322*da0073e9SAndroid Build Coastguard Worker            elif dtype in [torch.cfloat, torch.cdouble]:
4323*da0073e9SAndroid Build Coastguard Worker                return floating * (1 + 1j)
4324*da0073e9SAndroid Build Coastguard Worker            else:
4325*da0073e9SAndroid Build Coastguard Worker                return integer
4326*da0073e9SAndroid Build Coastguard Worker
4327*da0073e9SAndroid Build Coastguard Worker        def non_zero_rand(size, dtype, device):
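            # Any exact zeros are bumped to one below so the tensor is safe to
            # divide by (addcdiv divides by its last operand).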
4328*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point or dtype.is_complex:
4329*da0073e9SAndroid Build Coastguard Worker                a = torch.rand(size=size, dtype=dtype, device=device)
4330*da0073e9SAndroid Build Coastguard Worker            elif dtype == torch.uint8:
4331*da0073e9SAndroid Build Coastguard Worker                a = torch.randint(1, 5, size=size, dtype=dtype, device=device)
4332*da0073e9SAndroid Build Coastguard Worker            else:
4333*da0073e9SAndroid Build Coastguard Worker                a = torch.randint(-5, 5, size=size, dtype=dtype, device=device)
4334*da0073e9SAndroid Build Coastguard Worker            return a + (a == 0).to(dtype)
4335*da0073e9SAndroid Build Coastguard Worker
4336*da0073e9SAndroid Build Coastguard Worker        def _test_addcdiv():
4337*da0073e9SAndroid Build Coastguard Worker            a = non_zero_rand((2, 2), dtype=dtype, device=device)
4338*da0073e9SAndroid Build Coastguard Worker            b = non_zero_rand((2, 2), dtype=dtype, device=device)
4339*da0073e9SAndroid Build Coastguard Worker            c = non_zero_rand((2, 2), dtype=dtype, device=device)
4340*da0073e9SAndroid Build Coastguard Worker            alpha = _number(0.5, 3, dtype)
4341*da0073e9SAndroid Build Coastguard Worker
4342*da0073e9SAndroid Build Coastguard Worker            expected = a + (alpha * b) / c
4343*da0073e9SAndroid Build Coastguard Worker            actual = torch.addcdiv(a, b, c, value=alpha)
4344*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
4345*da0073e9SAndroid Build Coastguard Worker
4346*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsOnceRegex(
4347*da0073e9SAndroid Build Coastguard Worker                    UserWarning, "This overload of addcdiv is deprecated"):
4348*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, torch.addcdiv(a, alpha, b, c))
4349*da0073e9SAndroid Build Coastguard Worker
4350*da0073e9SAndroid Build Coastguard Worker        if not (dtype.is_floating_point or dtype.is_complex):
4351*da0073e9SAndroid Build Coastguard Worker            # Integer division with addcdiv is prohibited
4352*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
4353*da0073e9SAndroid Build Coastguard Worker                _test_addcdiv()
4354*da0073e9SAndroid Build Coastguard Worker        else:
4355*da0073e9SAndroid Build Coastguard Worker            _test_addcdiv()
4356*da0073e9SAndroid Build Coastguard Worker
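        # Sanity check, presumably guarding against fp16 intermediate overflow:
        # 60000 + (-2) * 60000 * 1 = -60000 is representable in half, so the
        # fused op should not produce inf or nan.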
4357*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and dtype == torch.half:
4358*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([60000.0], device=device, dtype=dtype)
4359*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([60000.0], device=device, dtype=dtype)
4360*da0073e9SAndroid Build Coastguard Worker            c = torch.tensor([1.0], device=device, dtype=dtype)
4361*da0073e9SAndroid Build Coastguard Worker            out = torch.addcmul(a, b, c, value=-2)
4362*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(not (out.isnan() or out.isinf()))
4363*da0073e9SAndroid Build Coastguard Worker
4364*da0073e9SAndroid Build Coastguard Worker    def test_nullary_op_mem_overlap(self, device):
4365*da0073e9SAndroid Build Coastguard Worker        ops = (
4366*da0073e9SAndroid Build Coastguard Worker            ("random_", ()),
4367*da0073e9SAndroid Build Coastguard Worker            ("uniform_", ()),
4368*da0073e9SAndroid Build Coastguard Worker            ("cauchy_", ()),
4369*da0073e9SAndroid Build Coastguard Worker            ("log_normal_", ()),
4370*da0073e9SAndroid Build Coastguard Worker            ("exponential_", ()),
4371*da0073e9SAndroid Build Coastguard Worker            ("geometric_", (0.5,)),
4372*da0073e9SAndroid Build Coastguard Worker            ("normal_", ()),
4373*da0073e9SAndroid Build Coastguard Worker        )
4374*da0073e9SAndroid Build Coastguard Worker
4375*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1, 3)).expand((3, 3))
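        # expand() yields a view with stride 0 along dim 0, so all three rows
        # alias the same storage; in-place RNG fills on such internally
        # overlapping tensors are expected to raise 'unsupported operation'.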
4376*da0073e9SAndroid Build Coastguard Worker        for op, args in ops:
4377*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4378*da0073e9SAndroid Build Coastguard Worker                getattr(x, op)(*args)
4379*da0073e9SAndroid Build Coastguard Worker
4380*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to an elementwise ternary test suite and make this an OpInfo test
4381*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/126474
4382*da0073e9SAndroid Build Coastguard Worker    @xfailIfTorchDynamo
4383*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/126474")
4384*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
4385*da0073e9SAndroid Build Coastguard Worker    def test_ternary_op_mem_overlap(self, device, dtype):
4386*da0073e9SAndroid Build Coastguard Worker        if device == "cpu" and TEST_WITH_TORCHINDUCTOR:
4387*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Failing on cpu")
4388*da0073e9SAndroid Build Coastguard Worker
4389*da0073e9SAndroid Build Coastguard Worker        ops = [
4390*da0073e9SAndroid Build Coastguard Worker            ("addcmul", True, True, 'cpu'),
4391*da0073e9SAndroid Build Coastguard Worker            ("addcmul", True, True, 'cuda'),
4392*da0073e9SAndroid Build Coastguard Worker            ("addcdiv", True, True, 'cpu'),
4393*da0073e9SAndroid Build Coastguard Worker            ("addcdiv", True, True, 'cuda'),
4394*da0073e9SAndroid Build Coastguard Worker            ("lerp", True, True, 'cpu'),
4395*da0073e9SAndroid Build Coastguard Worker            ("lerp", True, True, 'cuda')
4396*da0073e9SAndroid Build Coastguard Worker        ]
4397*da0073e9SAndroid Build Coastguard Worker
4398*da0073e9SAndroid Build Coastguard Worker        for (fn, has_input_output_mem_overlap_check,
4399*da0073e9SAndroid Build Coastguard Worker             has_internal_mem_overlap_check, dev) in ops:
4400*da0073e9SAndroid Build Coastguard Worker            if dev != device:
4401*da0073e9SAndroid Build Coastguard Worker                continue
4402*da0073e9SAndroid Build Coastguard Worker            out_op = getattr(torch, fn)
4403*da0073e9SAndroid Build Coastguard Worker            inplace_op = getattr(torch.Tensor, fn + '_')
4404*da0073e9SAndroid Build Coastguard Worker            self.check_internal_mem_overlap(
4405*da0073e9SAndroid Build Coastguard Worker                inplace_op, 3, dtype, device,
4406*da0073e9SAndroid Build Coastguard Worker                expected_failure=not has_internal_mem_overlap_check)
4407*da0073e9SAndroid Build Coastguard Worker            self.ternary_check_input_output_mem_overlap(out_op, dev,
4408*da0073e9SAndroid Build Coastguard Worker                                                        expected_failure=not has_input_output_mem_overlap_check)
4409*da0073e9SAndroid Build Coastguard Worker
4410*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError not raised
4411*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
4412*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4413*da0073e9SAndroid Build Coastguard Worker    def test_copy_mem_overlap(self, device, dtype):
4414*da0073e9SAndroid Build Coastguard Worker        self.check_internal_mem_overlap(
4415*da0073e9SAndroid Build Coastguard Worker            torch.Tensor.copy_, num_inputs=2, dtype=dtype, device=device)
4416*da0073e9SAndroid Build Coastguard Worker        sz = 9
4417*da0073e9SAndroid Build Coastguard Worker        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
4418*da0073e9SAndroid Build Coastguard Worker        self.unary_check_input_output_mem_overlap(
4419*da0073e9SAndroid Build Coastguard Worker            doubles, sz, lambda input, out: out.copy_(input))
4420*da0073e9SAndroid Build Coastguard Worker
4421*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4422*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors!)
4423*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4424*da0073e9SAndroid Build Coastguard Worker    def test_index_add_mem_overlap(self, device):
4425*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4426*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((6,), device=device)
4427*da0073e9SAndroid Build Coastguard Worker        ind = torch.tensor([2, 1, 0], device=device)
4428*da0073e9SAndroid Build Coastguard Worker        value = torch.rand((3,), device=device)
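        # x is an expanded view (all six elements alias one value) and the
        # remaining cases pass an index or source that overlaps the destination,
        # so each in-place index_add_ below should raise the mem-overlap error.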
4429*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4430*da0073e9SAndroid Build Coastguard Worker            x.index_add_(0, ind, value)
4431*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4432*da0073e9SAndroid Build Coastguard Worker            y.index_add_(0, ind, y[:3])
4433*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4434*da0073e9SAndroid Build Coastguard Worker            ind.index_add_(0, ind, ind.clone())
4435*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4436*da0073e9SAndroid Build Coastguard Worker            ind.index_add_(0, ind.clone(), ind)
4437*da0073e9SAndroid Build Coastguard Worker
4438*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4439*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors!)
4440*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4441*da0073e9SAndroid Build Coastguard Worker    def test_index_copy_mem_overlap(self, device):
4442*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4443*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((6,), device=device)
4444*da0073e9SAndroid Build Coastguard Worker        ind = torch.tensor([2, 1, 0], device=device)
4445*da0073e9SAndroid Build Coastguard Worker        value = torch.rand((3,), device=device)
4446*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4447*da0073e9SAndroid Build Coastguard Worker            x.index_copy_(0, ind, value)
4448*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4449*da0073e9SAndroid Build Coastguard Worker            y.index_copy_(0, ind, y[:3])
4450*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4451*da0073e9SAndroid Build Coastguard Worker            ind.index_copy_(0, ind, ind.clone())
4452*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4453*da0073e9SAndroid Build Coastguard Worker            ind.index_copy_(0, ind.clone(), ind)
4454*da0073e9SAndroid Build Coastguard Worker
4455*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4456*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors!)
4457*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # Warning not triggered
4458*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4459*da0073e9SAndroid Build Coastguard Worker    def test_index_fill_mem_overlap(self, device):
4460*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4461*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((6,), device=device)
4462*da0073e9SAndroid Build Coastguard Worker        ind = torch.tensor([2, 1, 0], device=device)
4463*da0073e9SAndroid Build Coastguard Worker        value = torch.rand((3,), device=device)
4464*da0073e9SAndroid Build Coastguard Worker
4465*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, "index_fill_ on expanded tensors"):
4466*da0073e9SAndroid Build Coastguard Worker            x.index_fill_(0, ind, 1.0)
4467*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4468*da0073e9SAndroid Build Coastguard Worker            ind.index_fill_(0, ind, 0)
4469*da0073e9SAndroid Build Coastguard Worker
4470*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4471*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError not raised
4472*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4473*da0073e9SAndroid Build Coastguard Worker    def test_shift_mem_overlap(self, device):
4474*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, device=device)
4475*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4476*da0073e9SAndroid Build Coastguard Worker            x[:-1] <<= x[1:]
4477*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4478*da0073e9SAndroid Build Coastguard Worker            x[:-1] >>= x[1:]
4479*da0073e9SAndroid Build Coastguard Worker
4480*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4481*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors)
4482*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError not raised
4483*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4484*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli_mem_overlap(self, device):
4485*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4486*da0073e9SAndroid Build Coastguard Worker
4487*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4488*da0073e9SAndroid Build Coastguard Worker            x.bernoulli_()
4489*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4490*da0073e9SAndroid Build Coastguard Worker            x.bernoulli_(p=0.1)
4491*da0073e9SAndroid Build Coastguard Worker        p = torch.rand(6, device=device)
4492*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4493*da0073e9SAndroid Build Coastguard Worker            x.bernoulli_(p=p)
4494*da0073e9SAndroid Build Coastguard Worker
4495*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4496*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors!)
4497*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError not raised
4498*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4499*da0073e9SAndroid Build Coastguard Worker    def test_put_mem_overlap(self, device):
4500*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4501*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((6,), device=device)
4502*da0073e9SAndroid Build Coastguard Worker        ind = torch.tensor([2, 1, 0], device=device)
4503*da0073e9SAndroid Build Coastguard Worker        value = torch.rand((3,), device=device)
4504*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4505*da0073e9SAndroid Build Coastguard Worker            x.put_(ind, value)
4506*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4507*da0073e9SAndroid Build Coastguard Worker            y.put_(ind[0], y[0])
4508*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4509*da0073e9SAndroid Build Coastguard Worker            ind.put_(ind, ind)
4510*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4511*da0073e9SAndroid Build Coastguard Worker            y.put_(ind, y[:3])
4512*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4513*da0073e9SAndroid Build Coastguard Worker            ind.put_(ind, ind.clone())
4514*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4515*da0073e9SAndroid Build Coastguard Worker            ind.put_(ind.clone(), ind)
4516*da0073e9SAndroid Build Coastguard Worker
4517*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4518*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors!)
4519*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # UserWarning not triggered
4520*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4521*da0073e9SAndroid Build Coastguard Worker    def test_index_put_mem_overlap(self, device):
4522*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4523*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((6,), device=device)
4524*da0073e9SAndroid Build Coastguard Worker        ind = torch.tensor([2, 1, 0], device=device)
4525*da0073e9SAndroid Build Coastguard Worker        value = torch.rand((3,), device=device)
4526*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
4527*da0073e9SAndroid Build Coastguard Worker            x.index_put_((ind,), value)
4528*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4529*da0073e9SAndroid Build Coastguard Worker            y.index_put_((ind,), y[0])
4530*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4531*da0073e9SAndroid Build Coastguard Worker            ind.index_put_((ind,), ind)
4532*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4533*da0073e9SAndroid Build Coastguard Worker            y.index_put_((ind,), y[:3])
4534*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4535*da0073e9SAndroid Build Coastguard Worker            ind.index_put_((ind,), ind.clone())
4536*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4537*da0073e9SAndroid Build Coastguard Worker            ind.index_put_((ind.clone(),), ind)
4538*da0073e9SAndroid Build Coastguard Worker
4539*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4540*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors!)
4541*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # UserWarning not triggered
4542*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4543*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill_mem_overlap(self, device):
4544*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4545*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor([True, False, True, True, False, False], device=device)
4546*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
4547*da0073e9SAndroid Build Coastguard Worker            x.masked_fill_(mask, 0.)
4548*da0073e9SAndroid Build Coastguard Worker
4549*da0073e9SAndroid Build Coastguard Worker        fill_val = torch.tensor(0., device=device)
4550*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
4551*da0073e9SAndroid Build Coastguard Worker            x.masked_fill_(mask, fill_val)
4552*da0073e9SAndroid Build Coastguard Worker
4553*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4554*da0073e9SAndroid Build Coastguard Worker            mask[1:].masked_fill_(mask[:-1], False)
4555*da0073e9SAndroid Build Coastguard Worker
4556*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4557*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors!)
4558*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError not raised
4559*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4560*da0073e9SAndroid Build Coastguard Worker    def test_masked_scatter_mem_overlap(self, device):
4561*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4562*da0073e9SAndroid Build Coastguard Worker        src = torch.rand((3,), device=device)
4563*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor([True, False, True, True, False, False], device=device)
4564*da0073e9SAndroid Build Coastguard Worker
4565*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4566*da0073e9SAndroid Build Coastguard Worker            x.masked_scatter_(mask, src)
4567*da0073e9SAndroid Build Coastguard Worker
4568*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert to ErrorInputs
4569*da0073e9SAndroid Build Coastguard Worker    # (but have to extend ErrorInputs to handle inplace-only errors!)
4570*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4571*da0073e9SAndroid Build Coastguard Worker    def test_scatter_mem_overlap(self, device):
4572*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device).expand((6,))
4573*da0073e9SAndroid Build Coastguard Worker        src = torch.rand((3,), device=device)
4574*da0073e9SAndroid Build Coastguard Worker        ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
4575*da0073e9SAndroid Build Coastguard Worker
4576*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4577*da0073e9SAndroid Build Coastguard Worker            x.scatter_(0, ind, src)
4578*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4579*da0073e9SAndroid Build Coastguard Worker            src.scatter_(0, ind, src)
4580*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
4581*da0073e9SAndroid Build Coastguard Worker            ind.scatter_(0, ind, ind.clone())
4582*da0073e9SAndroid Build Coastguard Worker
4583*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test distributions
4584*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4585*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_device_constrain(self, device):
4586*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(3, device="cpu")
4587*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(3, device=device)
4588*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
4589*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device",
4590*da0073e9SAndroid Build Coastguard Worker            lambda: torch.multinomial(x, 2, out=y))
4591*da0073e9SAndroid Build Coastguard Worker
4592*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test distributions
4593*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
4594*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4595*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("FIXME: error not thrown")
4596*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_gpu_device_constrain(self, devices):
4597*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(3, device=devices[0])
4598*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(3, device=devices[1], dtype=torch.long)
4599*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
4600*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device",
4601*da0073e9SAndroid Build Coastguard Worker            lambda: torch.multinomial(x, 2, out=y))
4602*da0073e9SAndroid Build Coastguard Worker
4603*da0073e9SAndroid Build Coastguard Worker    # FIXME: convert this to an automated OpInfo test
4604*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
4605*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4606*da0073e9SAndroid Build Coastguard Worker    def test_device_guard(self, devices):
4607*da0073e9SAndroid Build Coastguard Worker        # verify that all operators with `device_guard: False` behave properly with multiple devices.
4608*da0073e9SAndroid Build Coastguard Worker        # TODO: if we had operator introspection we could figure out this set of operators automatically...
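        # Roughly speaking, a device guard switches the current CUDA device to
        # the device of the operator's arguments before running its kernel.
        # Ops declared with `device_guard: False` skip that switch (most of
        # them only touch metadata), so they must still behave correctly when
        # their inputs live on a device other than the current one, e.g.
        #   x = torch.randn(3, device="cuda:1")
        #   x.is_floating_point()  # metadata-only query, no kernel launch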
4609*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 2, 3), device=devices[1])
4610*da0073e9SAndroid Build Coastguard Worker        y = torch.zeros((1, 3, 2), device=devices[1])
4611*da0073e9SAndroid Build Coastguard Worker        scalar = torch.tensor(5, device=devices[1])
4612*da0073e9SAndroid Build Coastguard Worker
4613*da0073e9SAndroid Build Coastguard Worker        # property ops
4614*da0073e9SAndroid Build Coastguard Worker        torch.cudnn_is_acceptable(x)
4615*da0073e9SAndroid Build Coastguard Worker        x.is_distributed()
4616*da0073e9SAndroid Build Coastguard Worker        x.is_floating_point()
4617*da0073e9SAndroid Build Coastguard Worker        x.is_complex()
4618*da0073e9SAndroid Build Coastguard Worker        x.is_same_size(y)
4619*da0073e9SAndroid Build Coastguard Worker        x.is_signed()
4620*da0073e9SAndroid Build Coastguard Worker        x.size(0)
4621*da0073e9SAndroid Build Coastguard Worker        x.stride(0)
4622*da0073e9SAndroid Build Coastguard Worker        x.numel()
4623*da0073e9SAndroid Build Coastguard Worker        x.is_set_to(y)
4624*da0073e9SAndroid Build Coastguard Worker        x.data_ptr()
4625*da0073e9SAndroid Build Coastguard Worker        scalar.is_nonzero()
4626*da0073e9SAndroid Build Coastguard Worker
4627*da0073e9SAndroid Build Coastguard Worker        # sparse property ops
4628*da0073e9SAndroid Build Coastguard Worker        y[0][1] = 5
4629*da0073e9SAndroid Build Coastguard Worker        y_sparse = y.to_sparse()
4630*da0073e9SAndroid Build Coastguard Worker        y_sparse.sparse_dim()
4631*da0073e9SAndroid Build Coastguard Worker        y_sparse._dimI()
4632*da0073e9SAndroid Build Coastguard Worker        y_sparse.dense_dim()
4633*da0073e9SAndroid Build Coastguard Worker        y_sparse._dimV()
4634*da0073e9SAndroid Build Coastguard Worker        y_sparse._nnz()
4635*da0073e9SAndroid Build Coastguard Worker        y_sparse.is_coalesced()
4636*da0073e9SAndroid Build Coastguard Worker        y_sparse._indices()
4637*da0073e9SAndroid Build Coastguard Worker        y_sparse._values()
4638*da0073e9SAndroid Build Coastguard Worker        y_sparse.indices()
4639*da0073e9SAndroid Build Coastguard Worker        y_sparse.values()
4640*da0073e9SAndroid Build Coastguard Worker
4641*da0073e9SAndroid Build Coastguard Worker        # in-place ops
4642*da0073e9SAndroid Build Coastguard Worker        def inplace():
4643*da0073e9SAndroid Build Coastguard Worker            return torch.randn((1, 2, 3), device=devices[1])
4644*da0073e9SAndroid Build Coastguard Worker        inplace().as_strided_(y.size(), y.stride())
4645*da0073e9SAndroid Build Coastguard Worker        inplace().resize_(y.size())
4646*da0073e9SAndroid Build Coastguard Worker        inplace().squeeze_()
4647*da0073e9SAndroid Build Coastguard Worker        inplace().squeeze_(0)
4648*da0073e9SAndroid Build Coastguard Worker        inplace().unsqueeze_(2)
4649*da0073e9SAndroid Build Coastguard Worker        inplace().transpose_(1, 2)
4650*da0073e9SAndroid Build Coastguard Worker        inplace().squeeze_().t_()
4651*da0073e9SAndroid Build Coastguard Worker        inplace().set_(x.storage())
4652*da0073e9SAndroid Build Coastguard Worker        inplace().set_(x.storage(), x.storage_offset(), x.size(), x.stride())
4653*da0073e9SAndroid Build Coastguard Worker        inplace().set_(x)
4654*da0073e9SAndroid Build Coastguard Worker        inplace().set_()
4655*da0073e9SAndroid Build Coastguard Worker        y_sparse._coalesced_(True)
4656*da0073e9SAndroid Build Coastguard Worker
4657*da0073e9SAndroid Build Coastguard Worker        # shape modification
4658*da0073e9SAndroid Build Coastguard Worker        x.as_strided(y.size(), y.stride())
4659*da0073e9SAndroid Build Coastguard Worker        x.expand((5, 2, 3))
4660*da0073e9SAndroid Build Coastguard Worker        x.expand_as(x)
4661*da0073e9SAndroid Build Coastguard Worker        x.sum_to_size((1,))
4662*da0073e9SAndroid Build Coastguard Worker        torch.broadcast_tensors(x, x)
4663*da0073e9SAndroid Build Coastguard Worker        x.reshape((1, 3, 2))
4664*da0073e9SAndroid Build Coastguard Worker        x.reshape_as(y)
4665*da0073e9SAndroid Build Coastguard Worker        x.squeeze()
4666*da0073e9SAndroid Build Coastguard Worker        x.squeeze(0)
4667*da0073e9SAndroid Build Coastguard Worker        x.squeeze().t()
4668*da0073e9SAndroid Build Coastguard Worker        x.transpose(1, 2)
4669*da0073e9SAndroid Build Coastguard Worker        x.unsqueeze(2)
4670*da0073e9SAndroid Build Coastguard Worker        x.view((1, 3, 2))
4671*da0073e9SAndroid Build Coastguard Worker        x.view_as(y)
4672*da0073e9SAndroid Build Coastguard Worker
4673*da0073e9SAndroid Build Coastguard Worker        # chunk, split, etc.
4674*da0073e9SAndroid Build Coastguard Worker        x.chunk(2, dim=1)
4675*da0073e9SAndroid Build Coastguard Worker        x.split(1, dim=2)
4676*da0073e9SAndroid Build Coastguard Worker        x.split_with_sizes([1, 2], dim=2)
4677*da0073e9SAndroid Build Coastguard Worker        x.unfold(dimension=2, size=1, step=1)
4678*da0073e9SAndroid Build Coastguard Worker
4679*da0073e9SAndroid Build Coastguard Worker        x.narrow(1, 1, 1)
4680*da0073e9SAndroid Build Coastguard Worker        x.select(1, 1)
4681*da0073e9SAndroid Build Coastguard Worker        torch.isnan(x)
4682*da0073e9SAndroid Build Coastguard Worker
4683*da0073e9SAndroid Build Coastguard Worker        torch.empty((1, 3, 2), out=y)
4684*da0073e9SAndroid Build Coastguard Worker        torch.empty_like(x)
4685*da0073e9SAndroid Build Coastguard Worker        torch.empty_like(x, dtype=torch.int64)
4686*da0073e9SAndroid Build Coastguard Worker
4687*da0073e9SAndroid Build Coastguard Worker        # to
4688*da0073e9SAndroid Build Coastguard Worker        x.to(x)
4689*da0073e9SAndroid Build Coastguard Worker        x.to(y)
4690*da0073e9SAndroid Build Coastguard Worker        x.to(x, copy=True)
4691*da0073e9SAndroid Build Coastguard Worker
4692*da0073e9SAndroid Build Coastguard Worker    def test_is_signed(self, device):
4693*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.IntTensor(5).to(device).is_signed(), True)
4694*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ByteTensor(5).to(device).is_signed(), False)
4695*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.CharTensor(5).to(device).is_signed(), True)
4696*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.FloatTensor(5).to(device).is_signed(), True)
4697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.HalfTensor(10).to(device).is_signed(), True)
4698*da0073e9SAndroid Build Coastguard Worker
4699*da0073e9SAndroid Build Coastguard Worker    def test_tensor_type(self):
4700*da0073e9SAndroid Build Coastguard Worker        for t in torch._tensor_classes:
4701*da0073e9SAndroid Build Coastguard Worker            if 'cuda' in t.__module__:
4702*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.is_cuda, True)
4703*da0073e9SAndroid Build Coastguard Worker            else:
4704*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.is_cuda, False)
4705*da0073e9SAndroid Build Coastguard Worker            if 'xpu' in t.__module__:
4706*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.is_xpu, True)
4707*da0073e9SAndroid Build Coastguard Worker            else:
4708*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.is_xpu, False)
4709*da0073e9SAndroid Build Coastguard Worker
4710*da0073e9SAndroid Build Coastguard Worker    # Note - reports a leak of 512 bytes on CUDA device 1
4711*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
4712*da0073e9SAndroid Build Coastguard Worker    @skipCUDAMemoryLeakCheckIf(True)
4713*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4714*da0073e9SAndroid Build Coastguard Worker    def test_tensor_set_errors_multigpu(self, devices):
4715*da0073e9SAndroid Build Coastguard Worker        f_cuda0 = torch.randn((2, 3), dtype=torch.float32, device=devices[0])
4716*da0073e9SAndroid Build Coastguard Worker        f_cuda1 = torch.randn((2, 3), dtype=torch.float32, device=devices[1])
4717*da0073e9SAndroid Build Coastguard Worker
4718*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1.storage()))
4719*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError,
4720*da0073e9SAndroid Build Coastguard Worker                          lambda: f_cuda0.set_(f_cuda1.storage(), 0, f_cuda1.size(), f_cuda1.stride()))
4721*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1))
4722*da0073e9SAndroid Build Coastguard Worker
4723*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_serialization
4724*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4725*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)  # Note: Test works with one device but prefers more
4726*da0073e9SAndroid Build Coastguard Worker    def test_serialization(self, devices):
4727*da0073e9SAndroid Build Coastguard Worker        def _test_serialization(filecontext_lambda):
4728*da0073e9SAndroid Build Coastguard Worker            t0 = torch.cuda.FloatTensor(5).fill_(1)
4729*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.device(devices[-1]):
4730*da0073e9SAndroid Build Coastguard Worker                tn = torch.cuda.FloatTensor(3).fill_(2)
4731*da0073e9SAndroid Build Coastguard Worker            torch.cuda.set_device(devices[0])
4732*da0073e9SAndroid Build Coastguard Worker            b = (t0, tn)
4733*da0073e9SAndroid Build Coastguard Worker            with filecontext_lambda() as f:
4734*da0073e9SAndroid Build Coastguard Worker                torch.save(b, f)
4735*da0073e9SAndroid Build Coastguard Worker                f.seek(0)
4736*da0073e9SAndroid Build Coastguard Worker                c = torch.load(f)
4737*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b, c, atol=0, rtol=0)
4738*da0073e9SAndroid Build Coastguard Worker                u0, un = c
4739*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(str(u0.device), devices[0])
4740*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(str(un.device), devices[-1])
4741*da0073e9SAndroid Build Coastguard Worker
4742*da0073e9SAndroid Build Coastguard Worker        _test_serialization(tempfile.NamedTemporaryFile)
4743*da0073e9SAndroid Build Coastguard Worker        _test_serialization(BytesIOContext)
4744*da0073e9SAndroid Build Coastguard Worker
4745*da0073e9SAndroid Build Coastguard Worker    # FIXME: move memory format tests to their own test class/suite
4746*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_preserved_after_permute(self, device):
4747*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 3, 8, 8, device=device)
4748*da0073e9SAndroid Build Coastguard Worker        nhwc = x.contiguous(memory_format=torch.channels_last)
4749*da0073e9SAndroid Build Coastguard Worker        y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2)
4750*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
4751*da0073e9SAndroid Build Coastguard Worker
4752*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 3, 8, 8, 8, device=device)
4753*da0073e9SAndroid Build Coastguard Worker        ndhwc = x.contiguous(memory_format=torch.channels_last_3d)
4754*da0073e9SAndroid Build Coastguard Worker        y = ndhwc.permute(0, 1, 4, 3, 2).permute(0, 1, 4, 3, 2)
4755*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d))
4756*da0073e9SAndroid Build Coastguard Worker
4757*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_propagation_rules(self, device):
4758*da0073e9SAndroid Build Coastguard Worker
4759*da0073e9SAndroid Build Coastguard Worker        contiguous = torch.rand(10, 3, 5, 5, device=device)
4760*da0073e9SAndroid Build Coastguard Worker        cl = torch.rand(10, 3, 5, 5, device=device).contiguous(memory_format=torch.channels_last)
4761*da0073e9SAndroid Build Coastguard Worker        ambiguous = torch.rand(10, 3, 1, 1, device=device).contiguous(memory_format=torch.channels_last)
4762*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ambiguous.is_contiguous(memory_format=torch.channels_last))
4763*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ambiguous.is_contiguous(memory_format=torch.contiguous_format))
4764*da0073e9SAndroid Build Coastguard Worker        bias = torch.rand(1, 1, 1, 1, device=device).contiguous(memory_format=torch.channels_last)
4765*da0073e9SAndroid Build Coastguard Worker
4766*da0073e9SAndroid Build Coastguard Worker        def _test_propagation_rules(self, contiguous, cl, ambiguous, bias):
4767*da0073e9SAndroid Build Coastguard Worker            options = ((ambiguous, contiguous, torch.contiguous_format),
4768*da0073e9SAndroid Build Coastguard Worker                       (ambiguous, cl, torch.channels_last),
4769*da0073e9SAndroid Build Coastguard Worker                       (contiguous, ambiguous, torch.contiguous_format),
4770*da0073e9SAndroid Build Coastguard Worker                       (contiguous, cl, torch.contiguous_format),
4771*da0073e9SAndroid Build Coastguard Worker                       (cl, ambiguous, torch.channels_last),
4772*da0073e9SAndroid Build Coastguard Worker                       (cl, contiguous, torch.channels_last),
4773*da0073e9SAndroid Build Coastguard Worker                       (bias, cl, torch.channels_last),
4774*da0073e9SAndroid Build Coastguard Worker                       (cl, bias, torch.channels_last),)
4775*da0073e9SAndroid Build Coastguard Worker
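            # Reading of the table above: the first operand with an
            # unambiguous layout determines the output format, an ambiguous
            # tensor (e.g. N x C x 1 x 1) adopts the other operand's layout,
            # and the fully-ambiguous bias defers to the channels_last operand.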
4776*da0073e9SAndroid Build Coastguard Worker            for a, b, mf in options:
4777*da0073e9SAndroid Build Coastguard Worker                result = a + b
4778*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(result.is_contiguous(memory_format=mf))
4779*da0073e9SAndroid Build Coastguard Worker
4780*da0073e9SAndroid Build Coastguard Worker        _test_propagation_rules(self, contiguous, cl, ambiguous, bias)
4781*da0073e9SAndroid Build Coastguard Worker
4782*da0073e9SAndroid Build Coastguard Worker        cl = cl.to(memory_format=torch.channels_last)
4783*da0073e9SAndroid Build Coastguard Worker        ambiguous = ambiguous.to(memory_format=torch.channels_last)
4784*da0073e9SAndroid Build Coastguard Worker        bias = bias.to(memory_format=torch.channels_last)
4785*da0073e9SAndroid Build Coastguard Worker
4786*da0073e9SAndroid Build Coastguard Worker        _test_propagation_rules(self, contiguous, cl, ambiguous, bias)
4787*da0073e9SAndroid Build Coastguard Worker
4788*da0073e9SAndroid Build Coastguard Worker        # test cases when strides matter in ambiguous tensors
4789*da0073e9SAndroid Build Coastguard Worker        for mf in (torch.channels_last, torch.contiguous_format):
4790*da0073e9SAndroid Build Coastguard Worker            ambiguous = torch.rand(10, 3, 1, 1, device=device).to(memory_format=mf)
4791*da0073e9SAndroid Build Coastguard Worker            bias = torch.rand(3, 1, 1, device=device)
4792*da0073e9SAndroid Build Coastguard Worker            result = ambiguous + bias
4793*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ambiguous.stride(), result.stride())
4794*da0073e9SAndroid Build Coastguard Worker            result = bias + ambiguous
4795*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ambiguous.stride(), result.stride())
4796*da0073e9SAndroid Build Coastguard Worker            result = ambiguous * 5
4797*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ambiguous.stride(), result.stride())
4798*da0073e9SAndroid Build Coastguard Worker
4799*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
4800*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_empty_like(self, device):
4801*da0073e9SAndroid Build Coastguard Worker        def test_helper(x, memory_format):
4802*da0073e9SAndroid Build Coastguard Worker            xc = x.contiguous(memory_format=memory_format)
4803*da0073e9SAndroid Build Coastguard Worker
4804*da0073e9SAndroid Build Coastguard Worker            like = torch.empty_like(xc, memory_format=torch.preserve_format)
4805*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(like.is_contiguous())
4806*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(like.is_contiguous(memory_format=memory_format))
4807*da0073e9SAndroid Build Coastguard Worker
4808*da0073e9SAndroid Build Coastguard Worker            like_x = torch.empty_like(x, memory_format=torch.preserve_format)
4809*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(like_x.is_contiguous())
4810*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(like_x.is_contiguous(memory_format=memory_format))
4811*da0073e9SAndroid Build Coastguard Worker
4812*da0073e9SAndroid Build Coastguard Worker            like = torch.empty_like(x, memory_format=memory_format)
4813*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(like.is_contiguous())
4814*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(like.is_contiguous(memory_format=memory_format))
4815*da0073e9SAndroid Build Coastguard Worker
4816*da0073e9SAndroid Build Coastguard Worker            like = torch.empty_like(xc, memory_format=torch.contiguous_format)
4817*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(like.is_contiguous())
4818*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(like.is_contiguous(memory_format=memory_format))
4819*da0073e9SAndroid Build Coastguard Worker
4820*da0073e9SAndroid Build Coastguard Worker            like = torch.empty_like(xc)
4821*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(like.is_contiguous())
4822*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(like.is_contiguous(memory_format=memory_format))
4823*da0073e9SAndroid Build Coastguard Worker
4824*da0073e9SAndroid Build Coastguard Worker            sparse = x.to_sparse()
4825*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
4826*da0073e9SAndroid Build Coastguard Worker                z = torch.empty_like(sparse, memory_format=torch.preserve_format)
4827*da0073e9SAndroid Build Coastguard Worker
4828*da0073e9SAndroid Build Coastguard Worker        test_helper(torch.randn(4, 3, 8, 8, device=device), torch.channels_last)
4829*da0073e9SAndroid Build Coastguard Worker        test_helper(torch.randn(4, 3, 8, 8, 8, device=device), torch.channels_last_3d)
4830*da0073e9SAndroid Build Coastguard Worker
4831*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_consistency(self, device):
4832*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 3, 1, 1, device=device)
4833*da0073e9SAndroid Build Coastguard Worker        x_rep = x.as_strided(x.size(), x.stride())
4834*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.size(), x_rep.size())
4835*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.stride(), x_rep.stride())
4836*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.is_contiguous(), x_rep.is_contiguous())
4837*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.is_contiguous(memory_format=torch.channels_last), x_rep.is_contiguous(memory_format=torch.channels_last))
4838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4839*da0073e9SAndroid Build Coastguard Worker            x.is_contiguous(memory_format=torch.channels_last_3d), x_rep.is_contiguous(memory_format=torch.channels_last_3d))
4840*da0073e9SAndroid Build Coastguard Worker
4841*da0073e9SAndroid Build Coastguard Worker    # FIXME: make this an elementwise unary and elementwise binary OpInfo test
4842*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_operators(self, device):
4843*da0073e9SAndroid Build Coastguard Worker        def _chunk_op(x, y):
4844*da0073e9SAndroid Build Coastguard Worker            x1, x2 = x.chunk(2, dim=1)
4845*da0073e9SAndroid Build Coastguard Worker            return x1 + x2
4846*da0073e9SAndroid Build Coastguard Worker
4847*da0073e9SAndroid Build Coastguard Worker        def _unsqueeze_op_add(x, y):
4848*da0073e9SAndroid Build Coastguard Worker            return x[0].unsqueeze(0) + 3
4849*da0073e9SAndroid Build Coastguard Worker
4850*da0073e9SAndroid Build Coastguard Worker        def _unsqueeze_op_clone(x, y):
4851*da0073e9SAndroid Build Coastguard Worker            return x[0].unsqueeze(0).clone()
4852*da0073e9SAndroid Build Coastguard Worker
4853*da0073e9SAndroid Build Coastguard Worker        def _test_helper(x, y, bias, memory_format):
4854*da0073e9SAndroid Build Coastguard Worker            return_contig_fns = [
4855*da0073e9SAndroid Build Coastguard Worker                lambda x, y: y + x,
4856*da0073e9SAndroid Build Coastguard Worker                lambda x, y: y * x,
4857*da0073e9SAndroid Build Coastguard Worker                lambda x, y: y.addcdiv(x, y, value=2),
4858*da0073e9SAndroid Build Coastguard Worker                lambda x, y: y.addcmul(x, y, value=2),
4859*da0073e9SAndroid Build Coastguard Worker            ]
4860*da0073e9SAndroid Build Coastguard Worker            bias_fns = [
4861*da0073e9SAndroid Build Coastguard Worker                lambda x, b: x + b,
4862*da0073e9SAndroid Build Coastguard Worker                lambda x, b: b + x,
4863*da0073e9SAndroid Build Coastguard Worker            ]
4864*da0073e9SAndroid Build Coastguard Worker            fns = [
4865*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.clone(),
4866*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x + 3,
4867*da0073e9SAndroid Build Coastguard Worker                lambda x, y: 3 * x,
4868*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x + y,
4869*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x * y,
4870*da0073e9SAndroid Build Coastguard Worker                lambda x, y: abs(x),
4871*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.abs(),
4872*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.abs_(),
4873*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.acos(),
4874*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.acos_(),
4875*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.add(y, alpha=3),
4876*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.add_(y, alpha=3),
4877*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.addcdiv(y, y, value=2),
4878*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.addcdiv_(y, y, value=2),
4879*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.addcmul(y, y, value=2),
4880*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.addcmul_(y, y, value=2),
4881*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.acosh(),
4882*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.acosh_(),
4883*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.asinh(),
4884*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.asinh_(),
4885*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.atanh(),
4886*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.atanh_(),
4887*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.asin(),
4888*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.asin_(),
4889*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.atan(),
4890*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.atan2(y),
4891*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.atan2_(y),
4892*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.ceil(),
4893*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.ceil_(),
4894*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.clamp(-1, 1),
4895*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.cos(),
4896*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.cosh(),
4897*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.div(0.5),
4898*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.div_(0.5),
4899*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.div(y),
4900*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.div_(y),
4901*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.digamma(),
4902*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.digamma_(),
4903*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.erf(),
4904*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.erfc(),
4905*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.erfinv(),
4906*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.erfinv_(),
4907*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.exp(),
4908*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.expm1(),
4909*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.expm1_(),
4910*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.floor(),
4911*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.floor_(),
4912*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.fmod(2),
4913*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.frac(),
4914*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.hypot(y),
4915*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.hypot_(y),
4916*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.i0(),
4917*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.i0_(),
4918*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.lerp(y, 0.5),
4919*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.log(),
4920*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.log_(),
4921*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.log10(),
4922*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.log10_(),
4923*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.log1p(),
4924*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.log1p_(),
4925*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.log2(),
4926*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.log2_(),
4927*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.mul(3),
4928*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.mul_(3),
4929*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.neg(),
4930*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.neg_(),
4931*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.pow(3),
4932*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.pow_(3),
4933*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.pow(0.0),
4934*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.pow(1.0),
4935*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.reciprocal(),
4936*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.remainder(2),
4937*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.round(),
4938*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.round_(),
4939*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.rsqrt(),
4940*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.rsqrt_(),
4941*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sigmoid(),
4942*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sigmoid_(),
4943*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.logit(),
4944*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.logit_(),
4945*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.logit(1e-6),
4946*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.logit_(1e-6),
4947*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sign(),
4948*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sign_(),
4949*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sgn(),
4950*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sgn_(),
4951*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sin(),
4952*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sin_(),
4953*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sinh(),
4954*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sinh_(),
4955*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sqrt(),
4956*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.sqrt_(),
4957*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.tan(),
4958*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.tanh(),
4959*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.trunc(),
4960*da0073e9SAndroid Build Coastguard Worker                lambda x, y: x.trunc_(),
4961*da0073e9SAndroid Build Coastguard Worker                _chunk_op,
4962*da0073e9SAndroid Build Coastguard Worker                _unsqueeze_op_add,
4963*da0073e9SAndroid Build Coastguard Worker                _unsqueeze_op_clone,
4964*da0073e9SAndroid Build Coastguard Worker            ]
4965*da0073e9SAndroid Build Coastguard Worker            x_c = x.contiguous()
4966*da0073e9SAndroid Build Coastguard Worker            y_c = y.contiguous()
4967*da0073e9SAndroid Build Coastguard Worker            b_c = bias.contiguous()
4968*da0073e9SAndroid Build Coastguard Worker            for fn in fns:
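                # Heuristic: the in-place lambdas call methods ending in '_',
                # so their source text contains "_(" ; clone the inputs for
                # those so repeated iterations don't keep mutating x / x_c.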
4969*da0073e9SAndroid Build Coastguard Worker                is_inplace = '_(' in inspect.getsource(fn)
4970*da0073e9SAndroid Build Coastguard Worker                x_clone = x.clone() if is_inplace else x
4971*da0073e9SAndroid Build Coastguard Worker                x_c_clone = x_c.clone() if is_inplace else x_c
4972*da0073e9SAndroid Build Coastguard Worker                result_c = fn(x_c_clone, y_c)
4973*da0073e9SAndroid Build Coastguard Worker                result = fn(x_clone, y)
4974*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
4975*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
4976*da0073e9SAndroid Build Coastguard Worker                    result.is_contiguous(memory_format=memory_format),
4977*da0073e9SAndroid Build Coastguard Worker                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{memory_format}' format")
4978*da0073e9SAndroid Build Coastguard Worker
4979*da0073e9SAndroid Build Coastguard Worker            for fn in bias_fns:
4980*da0073e9SAndroid Build Coastguard Worker                result_c = fn(x_c, b_c)
4981*da0073e9SAndroid Build Coastguard Worker                result = fn(x, bias)
4982*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
4983*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
4984*da0073e9SAndroid Build Coastguard Worker                    result.is_contiguous(memory_format=memory_format),
4985*da0073e9SAndroid Build Coastguard Worker                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{memory_format}' format")
4986*da0073e9SAndroid Build Coastguard Worker
4987*da0073e9SAndroid Build Coastguard Worker            for fn in return_contig_fns:
4988*da0073e9SAndroid Build Coastguard Worker                result_c = fn(x_c, y_c)
4989*da0073e9SAndroid Build Coastguard Worker                result = fn(x, y)
4990*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
4991*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
4992*da0073e9SAndroid Build Coastguard Worker                    result.is_contiguous(memory_format=torch.contiguous_format),
4993*da0073e9SAndroid Build Coastguard Worker                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{torch.contiguous_format}' format")
4994*da0073e9SAndroid Build Coastguard Worker
4995*da0073e9SAndroid Build Coastguard Worker        _test_helper(
4996*da0073e9SAndroid Build Coastguard Worker            torch.randn((4, 3, 8, 8), device=device).contiguous(memory_format=torch.channels_last),
4997*da0073e9SAndroid Build Coastguard Worker            abs(torch.randn((4, 3, 8, 8), device=device)) + 1,
4998*da0073e9SAndroid Build Coastguard Worker            torch.randn((1, 3, 1, 1), device=device).contiguous(memory_format=torch.channels_last),
4999*da0073e9SAndroid Build Coastguard Worker            torch.channels_last)
5000*da0073e9SAndroid Build Coastguard Worker        _test_helper(
5001*da0073e9SAndroid Build Coastguard Worker            torch.randn((4, 3, 8, 8, 8), device=device).contiguous(memory_format=torch.channels_last_3d),
5002*da0073e9SAndroid Build Coastguard Worker            abs(torch.randn((4, 3, 8, 8, 8), device=device)) + 1,
5003*da0073e9SAndroid Build Coastguard Worker            torch.randn((1, 3, 1, 1, 1), device=device).contiguous(memory_format=torch.channels_last_3d),
5004*da0073e9SAndroid Build Coastguard Worker            torch.channels_last_3d)
5005*da0073e9SAndroid Build Coastguard Worker
5006*da0073e9SAndroid Build Coastguard Worker    # FIXME: make this an elementwise unary and elementwise binary OpInfo test
5007*da0073e9SAndroid Build Coastguard Worker    def test_strides_propagation(self, device):
5008*da0073e9SAndroid Build Coastguard Worker        def _test_helper(x, op, unary=False):
5009*da0073e9SAndroid Build Coastguard Worker            def compare_strides(s1, s2, div):
5010*da0073e9SAndroid Build Coastguard Worker                sdiv = [s // div for s in s1]
5011*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sdiv, s2)
5012*da0073e9SAndroid Build Coastguard Worker
5013*da0073e9SAndroid Build Coastguard Worker            dim = x.dim()
5014*da0073e9SAndroid Build Coastguard Worker            # the ops produce memory-dense outputs, so when the input is strided along the last
5015*da0073e9SAndroid Build Coastguard Worker            # dimension we divide by that stride before comparing input and result strides
5016*da0073e9SAndroid Build Coastguard Worker            div = x.stride(-1)
5017*da0073e9SAndroid Build Coastguard Worker            for p in permutations(range(dim)):
5018*da0073e9SAndroid Build Coastguard Worker                xp = x.permute(p)
5019*da0073e9SAndroid Build Coastguard Worker                if not unary:
5020*da0073e9SAndroid Build Coastguard Worker                    y = torch.randn(xp.size(-1), device=x.device, dtype=x.dtype)
5021*da0073e9SAndroid Build Coastguard Worker                    for inputs in ((xp, xp), (xp, y), (y, xp)):
5022*da0073e9SAndroid Build Coastguard Worker                        res = op(*inputs)
5023*da0073e9SAndroid Build Coastguard Worker                        compare_strides(xp.stride(), res.stride(), div)
5024*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(xp.size(), res.size())
5025*da0073e9SAndroid Build Coastguard Worker                        out = torch.empty(0, device=xp.device, dtype=res.dtype)
5026*da0073e9SAndroid Build Coastguard Worker                        res = op(*inputs, out=out)
5027*da0073e9SAndroid Build Coastguard Worker                        compare_strides(xp.stride(), res.stride(), div)
5028*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(xp.size(), res.size())
5029*da0073e9SAndroid Build Coastguard Worker                else:
5030*da0073e9SAndroid Build Coastguard Worker                    res = op(xp)
5031*da0073e9SAndroid Build Coastguard Worker                    compare_strides(xp.stride(), res.stride(), div)
5032*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(xp.size(), res.size())
5033*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty(0, device=xp.device, dtype=res.dtype)
5034*da0073e9SAndroid Build Coastguard Worker                    res = op(xp, out=out)
5035*da0073e9SAndroid Build Coastguard Worker                    compare_strides(xp.stride(), res.stride(), div)
5036*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(xp.size(), res.size())
5037*da0073e9SAndroid Build Coastguard Worker
5038*da0073e9SAndroid Build Coastguard Worker        # torch.eq by default calls TensorIterator with defined output, torch.add with undefined
5039*da0073e9SAndroid Build Coastguard Worker        binary_ops = (torch.eq, torch.add)
5040*da0073e9SAndroid Build Coastguard Worker        unary_ops = (torch.exp,)
5041*da0073e9SAndroid Build Coastguard Worker        # memory dense, sliced and ambiguous sliced (ambiguous dense loses permutation information)
5042*da0073e9SAndroid Build Coastguard Worker        xs = (torch.randn(2, 3, 4, device=device), torch.randn(2, 3, 8, device=device)[:, :, ::2],
5043*da0073e9SAndroid Build Coastguard Worker              torch.randn(1, 1, 4, 12, device=device)[:, :, :, ::2])
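        # Illustrative expectation (dense case): a permuted dense input keeps
        # its permuted strides in the output, e.g.
        #   xp = torch.randn(2, 3, 4).permute(2, 0, 1)  # strides (1, 12, 4)
        #   torch.exp(xp).stride()                       # also (1, 12, 4)
        # For the sliced inputs the comparison above divides by the
        # last-dimension stride first.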
5044*da0073e9SAndroid Build Coastguard Worker        for op in binary_ops:
5045*da0073e9SAndroid Build Coastguard Worker            for x in xs:
5046*da0073e9SAndroid Build Coastguard Worker                _test_helper(x, op)
5047*da0073e9SAndroid Build Coastguard Worker        for op in unary_ops:
5048*da0073e9SAndroid Build Coastguard Worker            for x in xs:
5049*da0073e9SAndroid Build Coastguard Worker                _test_helper(x, op, unary=True)
5050*da0073e9SAndroid Build Coastguard Worker
5051*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
5052*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
5053*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("NotImplementedError: PrimTorch does not support pinned memory")
5054*da0073e9SAndroid Build Coastguard Worker    def test_pin_memory_from_constructor(self, device):
5055*da0073e9SAndroid Build Coastguard Worker        def _get_like(t, **kwargs):
5056*da0073e9SAndroid Build Coastguard Worker            return [
5057*da0073e9SAndroid Build Coastguard Worker                torch.rand_like(t, **kwargs),
5058*da0073e9SAndroid Build Coastguard Worker                torch.randn_like(t, **kwargs),
5059*da0073e9SAndroid Build Coastguard Worker                torch.empty_like(t, **kwargs),
5060*da0073e9SAndroid Build Coastguard Worker                torch.full_like(t, 4, **kwargs),
5061*da0073e9SAndroid Build Coastguard Worker                torch.zeros_like(t, **kwargs),
5062*da0073e9SAndroid Build Coastguard Worker                torch.ones_like(t, **kwargs),
5063*da0073e9SAndroid Build Coastguard Worker            ]
5064*da0073e9SAndroid Build Coastguard Worker
5065*da0073e9SAndroid Build Coastguard Worker        def _get_tensors(**kwargs):
5066*da0073e9SAndroid Build Coastguard Worker            return [
5067*da0073e9SAndroid Build Coastguard Worker                torch.tensor([10, 11], **kwargs),
5068*da0073e9SAndroid Build Coastguard Worker                torch.randn(3, 5, **kwargs),
5069*da0073e9SAndroid Build Coastguard Worker                torch.rand(3, **kwargs),
5070*da0073e9SAndroid Build Coastguard Worker                # torch.randint(3, 5, **kwargs),  # unsupported
5071*da0073e9SAndroid Build Coastguard Worker                torch.zeros(3, **kwargs),
5072*da0073e9SAndroid Build Coastguard Worker                torch.randperm(3, **kwargs),
5073*da0073e9SAndroid Build Coastguard Worker                torch.empty(6, **kwargs),
5074*da0073e9SAndroid Build Coastguard Worker                torch.ones(6, **kwargs),
5075*da0073e9SAndroid Build Coastguard Worker                torch.eye(6, **kwargs),
5076*da0073e9SAndroid Build Coastguard Worker                torch.arange(3, 5, **kwargs)]
5077*da0073e9SAndroid Build Coastguard Worker
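        # Pinned (page-locked) host memory is what enables asynchronous
        # host-to-device copies; this test only checks that the
        # `pin_memory=True` constructor argument is reflected by `is_pinned()`.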
5078*da0073e9SAndroid Build Coastguard Worker        pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True)
5079*da0073e9SAndroid Build Coastguard Worker        for x in pinned_tensors:
5080*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.is_pinned())
5081*da0073e9SAndroid Build Coastguard Worker
5082*da0073e9SAndroid Build Coastguard Worker        tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True))
5083*da0073e9SAndroid Build Coastguard Worker        for x in tensors:
5084*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_pinned())
5085*da0073e9SAndroid Build Coastguard Worker
5086*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
5087*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
5088*da0073e9SAndroid Build Coastguard Worker    def test_storage_all_devices(self, devices):
5089*da0073e9SAndroid Build Coastguard Worker        for device in devices:
5090*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor((), device=device)
5091*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.dtype, t.storage().dtype)
5092*da0073e9SAndroid Build Coastguard Worker
5093*da0073e9SAndroid Build Coastguard Worker    # Note [lazy_clone_ tests with inductor enabled]
5094*da0073e9SAndroid Build Coastguard Worker    # These `lazy_clone_` tests are written in a way that makes them pass in
5095*da0073e9SAndroid Build Coastguard Worker    # both eager mode and compiled mode (`PYTORCH_TEST_WITH_INDUCTOR=1`). There
5096*da0073e9SAndroid Build Coastguard Worker    # are cases where COW tensors can materialize at different times and in
5097*da0073e9SAndroid Build Coastguard Worker    # different ways in compiled mode versus eager mode, and those cases need to
5098*da0073e9SAndroid Build Coastguard Worker    # be avoided. There are two main wrinkles to be aware of.
5099*da0073e9SAndroid Build Coastguard Worker    #
5100*da0073e9SAndroid Build Coastguard Worker    # The first wrinkle is that these tests have to check the internal
5101*da0073e9SAndroid Build Coastguard Worker    # properties of tensors to make sure they materialize in the expected way,
5102*da0073e9SAndroid Build Coastguard Worker    # and those checks cause dynamo graph breaks. Depending on the situation, a
5103*da0073e9SAndroid Build Coastguard Worker    # graph break in-between two compiled graphs that operate on the same COW
5104*da0073e9SAndroid Build Coastguard Worker    # tensor can make the tensor materialize when it would not materialize in
5105*da0073e9SAndroid Build Coastguard Worker    # eager mode, causing the checks to fail. The strategy for avoiding this is
5106*da0073e9SAndroid Build Coastguard Worker    # to make all the operations on COW tensors get compiled into the same
5107*da0073e9SAndroid Build Coastguard Worker    # graph, by not doing any checks between the operations and just doing all the
5108*da0073e9SAndroid Build Coastguard Worker    # checks at the end of the test. If we really do want to perform checks
5109*da0073e9SAndroid Build Coastguard Worker    # between two operations, `op1` and `op2`, the solution is to create two
5110*da0073e9SAndroid Build Coastguard Worker    # different tests. One test performs just `op1` and then checks. The other
5111*da0073e9SAndroid Build Coastguard Worker    # test performs `op1` followed immediately by `op2` and then checks; see the commented sketch at the end of this note.
5112*da0073e9SAndroid Build Coastguard Worker    #
5113*da0073e9SAndroid Build Coastguard Worker    # The second wrinkle is that in eager mode, if we perform writes on two COW
5114*da0073e9SAndroid Build Coastguard Worker    # tensors where one is a lazy clone of the other, the first tensor to be
5115*da0073e9SAndroid Build Coastguard Worker    # written will be materialized with a new data pointer, and the second
5116*da0073e9SAndroid Build Coastguard Worker    # tensor will just reuse the original data pointer when it is materialized.
5117*da0073e9SAndroid Build Coastguard Worker    # But in compiled mode, if these writes happen in the same graph, the order
5118*da0073e9SAndroid Build Coastguard Worker    # in which the tensors materialize can be different than in eager mode. So
5119*da0073e9SAndroid Build Coastguard Worker    # in this case the strategy is to purposefully cause a graph break to happen
5120*da0073e9SAndroid Build Coastguard Worker    # in-between the two write operations, by adding checks between them, so
5121*da0073e9SAndroid Build Coastguard Worker    # that they have to materialize in the expected order.
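    #
    # A minimal commented sketch of the first strategy (`op1` and `op2` are
    # hypothetical placeholder operations, not tests that exist in this file):
    #
    #     def test_op1(self, device, dtype):
    #         t = torch.tensor([0], device=device, dtype=dtype)._lazy_clone()
    #         op1(t)
    #         # all checks happen here, after the op
    #
    #     def test_op1_then_op2(self, device, dtype):
    #         t = torch.tensor([0], device=device, dtype=dtype)._lazy_clone()
    #         op1(t)
    #         op2(t)  # no checks in between, so both ops compile into one graph
    #         # all checks happen here, after the ops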
5122*da0073e9SAndroid Build Coastguard Worker    @skipXLA
5123*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5124*da0073e9SAndroid Build Coastguard Worker    def test_lazy_clone(self, device, dtype):
5125*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5126*da0073e9SAndroid Build Coastguard Worker        t_orig_storage_addr = torch._C._storage_address(t)
5127*da0073e9SAndroid Build Coastguard Worker        orig_data_ptr = torch._C._data_address(t)
5128*da0073e9SAndroid Build Coastguard Worker        clone = t._lazy_clone()
5129*da0073e9SAndroid Build Coastguard Worker
5130*da0073e9SAndroid Build Coastguard Worker        # Lazy cloning a tensor should cause both it and its clone to become COW
5131*da0073e9SAndroid Build Coastguard Worker        # tensors. They should have different storages, but the same data
5132*da0073e9SAndroid Build Coastguard Worker        # pointer.
5133*da0073e9SAndroid Build Coastguard Worker
5134*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_cow_tensor(clone))
5135*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_cow_tensor(t))
5136*da0073e9SAndroid Build Coastguard Worker
5137*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
5138*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
5139*da0073e9SAndroid Build Coastguard Worker
5140*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
5141*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
5142*da0073e9SAndroid Build Coastguard Worker
5143*da0073e9SAndroid Build Coastguard Worker    # See Note [lazy_clone_ tests with inductor enabled]
5144*da0073e9SAndroid Build Coastguard Worker    @skipXLA
5145*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5146*da0073e9SAndroid Build Coastguard Worker    def test_lazy_clone_view(self, device, dtype):
5147*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5148*da0073e9SAndroid Build Coastguard Worker        t_orig_storage_addr = torch._C._storage_address(t)
5149*da0073e9SAndroid Build Coastguard Worker        orig_data_ptr = torch._C._data_address(t)
5150*da0073e9SAndroid Build Coastguard Worker        clone = t._lazy_clone()
5151*da0073e9SAndroid Build Coastguard Worker        view = t.view([4])
5152*da0073e9SAndroid Build Coastguard Worker
5153*da0073e9SAndroid Build Coastguard Worker        # Viewing `t` should not cause a copy (materialize) to happen. All the
5154*da0073e9SAndroid Build Coastguard Worker        # tensors should still be COW and have the same data pointer. `view` and
5155*da0073e9SAndroid Build Coastguard Worker        # `t` should have the same storage, and `clone` should have a different
5156*da0073e9SAndroid Build Coastguard Worker        # storage.
5157*da0073e9SAndroid Build Coastguard Worker
5158*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_cow_tensor(t))
5159*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_cow_tensor(view))
5160*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_cow_tensor(clone))
5161*da0073e9SAndroid Build Coastguard Worker
5162*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
5163*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
5164*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
5165*da0073e9SAndroid Build Coastguard Worker
5166*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
5167*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
5168*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(view) == orig_data_ptr)
5169*da0073e9SAndroid Build Coastguard Worker
5170*da0073e9SAndroid Build Coastguard Worker    # See Note [lazy_clone_ tests with inductor enabled]
5171*da0073e9SAndroid Build Coastguard Worker    @skipXLA
5172*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5173*da0073e9SAndroid Build Coastguard Worker    def test_lazy_clone_view_materialize(self, device, dtype):
5174*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5175*da0073e9SAndroid Build Coastguard Worker        t_orig_storage_addr = torch._C._storage_address(t)
5176*da0073e9SAndroid Build Coastguard Worker        orig_data_ptr = torch._C._data_address(t)
5177*da0073e9SAndroid Build Coastguard Worker        clone = t._lazy_clone()
5178*da0073e9SAndroid Build Coastguard Worker        view = t.view([4])
5179*da0073e9SAndroid Build Coastguard Worker        view += torch.ones(1, device=device, dtype=dtype)
5180*da0073e9SAndroid Build Coastguard Worker
5181*da0073e9SAndroid Build Coastguard Worker        # Writing through `view` should cause the storage shared by `t` and `view`
5182*da0073e9SAndroid Build Coastguard Worker        # to be copied (materialized), but should not affect `clone`.
5183*da0073e9SAndroid Build Coastguard Worker
5184*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._is_cow_tensor(t))
5185*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._is_cow_tensor(view))
5186*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_cow_tensor(clone))
5187*da0073e9SAndroid Build Coastguard Worker
5188*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
5189*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
5190*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
5191*da0073e9SAndroid Build Coastguard Worker
5192*da0073e9SAndroid Build Coastguard Worker        t_new_data_addr = torch._C._data_address(t)
5193*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t_new_data_addr != orig_data_ptr)
5194*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
5195*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
5196*da0073e9SAndroid Build Coastguard Worker
5197*da0073e9SAndroid Build Coastguard Worker        clone += torch.ones(1, device=device, dtype=dtype)
5198*da0073e9SAndroid Build Coastguard Worker
5199*da0073e9SAndroid Build Coastguard Worker        # Writing to `clone` should materialize it, so it should no longer
5200*da0073e9SAndroid Build Coastguard Worker        # be COW. However, since `clone`'s storage is the only COW storage
5201*da0073e9SAndroid Build Coastguard Worker        # left that holds a reference to the original data pointer, this
5202*da0073e9SAndroid Build Coastguard Worker        # materialization should not actually cause a copy--it should
5203*da0073e9SAndroid Build Coastguard Worker        # just reuse the original data pointer.
5204*da0073e9SAndroid Build Coastguard Worker
5205*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._is_cow_tensor(t))
5206*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._is_cow_tensor(view))
5207*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._is_cow_tensor(clone))
5208*da0073e9SAndroid Build Coastguard Worker
5209*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
5210*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
5211*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
5212*da0073e9SAndroid Build Coastguard Worker
5213*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(t) == t_new_data_addr)
5214*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
5215*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
5216*da0073e9SAndroid Build Coastguard Worker
5217*da0073e9SAndroid Build Coastguard Worker    @skipXLA
5218*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5219*da0073e9SAndroid Build Coastguard Worker    def test_lazy_clone_binary_op_no_materialize(self, device, dtype):
5220*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5221*da0073e9SAndroid Build Coastguard Worker        clone = t._lazy_clone()
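        # A read-only binary op on two COW tensors should not trigger
        # materialization; the result of the op itself is not important here.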
5222*da0073e9SAndroid Build Coastguard Worker        res = t + clone
5223*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_cow_tensor(t))
5224*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_cow_tensor(clone))
5225*da0073e9SAndroid Build Coastguard Worker
5226*da0073e9SAndroid Build Coastguard Worker    # This tests that if a COW materialization is attempted inside an
5227*da0073e9SAndroid Build Coastguard Worker    # `at::parallel_for` loop function, then an error is raised. This test is
5228*da0073e9SAndroid Build Coastguard Worker    # implemented in Python rather than C++ because the C++ tests are built
5229*da0073e9SAndroid Build Coastguard Worker    # without multithreading support in `at::parallel_for`.
5230*da0073e9SAndroid Build Coastguard Worker    @skipXLA
5231*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Torchdynamo fails and we do not need to test it here anyway")
5232*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5233*da0073e9SAndroid Build Coastguard Worker    def test_parallel_cow_materialize_error(self, device, dtype):
5234*da0073e9SAndroid Build Coastguard Worker
5235*da0073e9SAndroid Build Coastguard Worker        def run(num_threads, num_parallel, skip_first, should_error):
5236*da0073e9SAndroid Build Coastguard Worker            orig_num_threads = torch.get_num_threads()
5237*da0073e9SAndroid Build Coastguard Worker
5238*da0073e9SAndroid Build Coastguard Worker            try:
5239*da0073e9SAndroid Build Coastguard Worker                torch.set_num_threads(num_threads)
5240*da0073e9SAndroid Build Coastguard Worker
5241*da0073e9SAndroid Build Coastguard Worker                a = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)._lazy_clone()
5242*da0073e9SAndroid Build Coastguard Worker
5243*da0073e9SAndroid Build Coastguard Worker                if should_error:
5244*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, r'Materializing a storage'):
5245*da0073e9SAndroid Build Coastguard Worker                        torch._test_parallel_materialize(
5246*da0073e9SAndroid Build Coastguard Worker                            a, num_parallel, skip_first)
5247*da0073e9SAndroid Build Coastguard Worker                else:
5248*da0073e9SAndroid Build Coastguard Worker                    torch._test_parallel_materialize(a, num_parallel, skip_first)
5249*da0073e9SAndroid Build Coastguard Worker
5250*da0073e9SAndroid Build Coastguard Worker                # No error should be raised in any case if the tensor is not COW
5251*da0073e9SAndroid Build Coastguard Worker                b = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
5252*da0073e9SAndroid Build Coastguard Worker                torch._test_parallel_materialize(b, num_parallel, skip_first)
5253*da0073e9SAndroid Build Coastguard Worker
5254*da0073e9SAndroid Build Coastguard Worker            finally:
5255*da0073e9SAndroid Build Coastguard Worker                torch.set_num_threads(orig_num_threads)
5256*da0073e9SAndroid Build Coastguard Worker
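        # Presumably `skip_first` skips materialization in the first parallel task,
        # so the only combinations expected not to error on a COW tensor are those
        # with num_parallel == 1 and skip_first == True.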
5257*da0073e9SAndroid Build Coastguard Worker        run(1, 1, False, True)
5258*da0073e9SAndroid Build Coastguard Worker        run(1, 1, True, False)
5259*da0073e9SAndroid Build Coastguard Worker        run(1, 10, False, True)
5260*da0073e9SAndroid Build Coastguard Worker        run(1, 10, True, True)
5261*da0073e9SAndroid Build Coastguard Worker        run(10, 1, False, True)
5262*da0073e9SAndroid Build Coastguard Worker        run(10, 1, True, False)
5263*da0073e9SAndroid Build Coastguard Worker        run(10, 10, False, True)
5264*da0073e9SAndroid Build Coastguard Worker        run(10, 10, True, True)
5265*da0073e9SAndroid Build Coastguard Worker        run(10, 2, False, True)
5266*da0073e9SAndroid Build Coastguard Worker        run(10, 2, True, True)
5267*da0073e9SAndroid Build Coastguard Worker
5268*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test distributions
5269*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
5270*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.double, torch.half)
5271*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.half)
5272*da0073e9SAndroid Build Coastguard Worker    def test_multinomial(self, device, dtype):
5273*da0073e9SAndroid Build Coastguard Worker        def make_prob_dist(shape, is_contiguous):
5274*da0073e9SAndroid Build Coastguard Worker            if is_contiguous:
5275*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half:
5276*da0073e9SAndroid Build Coastguard Worker                    return torch.zeros(shape, device=device).uniform_().to(dtype=torch.half)
5277*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(shape, device=device, dtype=dtype).uniform_()
5278*da0073e9SAndroid Build Coastguard Worker            elif len(shape) == 1:
5279*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half:
5280*da0073e9SAndroid Build Coastguard Worker                    return torch.zeros((shape + [5]), device=device).uniform_().to(dtype=torch.half)[:, 2]
5281*da0073e9SAndroid Build Coastguard Worker                return torch.zeros((shape + [5]), device=device, dtype=dtype).uniform_()[:, 2]
5282*da0073e9SAndroid Build Coastguard Worker            else:
5283*da0073e9SAndroid Build Coastguard Worker                # num dim = 2
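                # Build a larger 7-d tensor, then transpose and index it so the
                # result is a non-contiguous view of shape (shape[0], shape[1]).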
5284*da0073e9SAndroid Build Coastguard Worker                new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
5285*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half:
5286*da0073e9SAndroid Build Coastguard Worker                    prob_dist = torch.zeros(new_shape, device=device).uniform_().to(dtype=torch.half)
5287*da0073e9SAndroid Build Coastguard Worker                else:
5288*da0073e9SAndroid Build Coastguard Worker                    prob_dist = torch.zeros(new_shape, device=device, dtype=dtype).uniform_()
5289*da0073e9SAndroid Build Coastguard Worker                prob_dist = prob_dist.transpose(1, 4)
5290*da0073e9SAndroid Build Coastguard Worker                prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
5291*da0073e9SAndroid Build Coastguard Worker                assert not prob_dist.is_contiguous()  # sanity check
5292*da0073e9SAndroid Build Coastguard Worker                return prob_dist
5293*da0073e9SAndroid Build Coastguard Worker
5294*da0073e9SAndroid Build Coastguard Worker        for is_contiguous in (True, False):
5295*da0073e9SAndroid Build Coastguard Worker            # with replacement
5296*da0073e9SAndroid Build Coastguard Worker            n_row = 3
5297*da0073e9SAndroid Build Coastguard Worker            for n_col in range(4, 5 + 1):
5298*da0073e9SAndroid Build Coastguard Worker                prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
5299*da0073e9SAndroid Build Coastguard Worker                # indices that shouldn't be sampled (<0 means none)
5300*da0073e9SAndroid Build Coastguard Worker                zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist()
5301*da0073e9SAndroid Build Coastguard Worker                for i, j in enumerate(zero_prob_indices):
5302*da0073e9SAndroid Build Coastguard Worker                    if j >= 0:
5303*da0073e9SAndroid Build Coastguard Worker                        prob_dist[i, j] = 0
5304*da0073e9SAndroid Build Coastguard Worker                n_sample = n_col * 3
5305*da0073e9SAndroid Build Coastguard Worker                sample_indices = torch.multinomial(prob_dist, n_sample, True)
5306*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(prob_dist.dim(), 2)
5307*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sample_indices.size(1), n_sample)
5308*da0073e9SAndroid Build Coastguard Worker                for i in range(n_row):
5309*da0073e9SAndroid Build Coastguard Worker                    zero_prob_idx = zero_prob_indices[i]
5310*da0073e9SAndroid Build Coastguard Worker                    if zero_prob_idx < 0:
5311*da0073e9SAndroid Build Coastguard Worker                        continue
5312*da0073e9SAndroid Build Coastguard Worker                    for j in range(n_sample):
5313*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(sample_indices[i, j], zero_prob_idx,
5314*da0073e9SAndroid Build Coastguard Worker                                            msg="sampled an index with zero probability")
5315*da0073e9SAndroid Build Coastguard Worker
5316*da0073e9SAndroid Build Coastguard Worker            # without replacement
5317*da0073e9SAndroid Build Coastguard Worker            n_row = 3
5318*da0073e9SAndroid Build Coastguard Worker            for n_col in range(2, 10 + 1, 2):
5319*da0073e9SAndroid Build Coastguard Worker                prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
5320*da0073e9SAndroid Build Coastguard Worker                # indices that shouldn't be sampled (<0 means none)
5321*da0073e9SAndroid Build Coastguard Worker                zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist()
5322*da0073e9SAndroid Build Coastguard Worker                for i, j in enumerate(zero_prob_indices):
5323*da0073e9SAndroid Build Coastguard Worker                    if j >= 0:
5324*da0073e9SAndroid Build Coastguard Worker                        prob_dist[i, j] = 0
5325*da0073e9SAndroid Build Coastguard Worker                n_sample = max(1, n_col - 2)
5326*da0073e9SAndroid Build Coastguard Worker                sample_indices = torch.multinomial(prob_dist, n_sample, False)
5327*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(prob_dist.dim(), 2)
5328*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sample_indices.size(1), n_sample)
5329*da0073e9SAndroid Build Coastguard Worker                for i in range(n_row):
5330*da0073e9SAndroid Build Coastguard Worker                    row_samples = {}
5331*da0073e9SAndroid Build Coastguard Worker                    zero_prob_idx = zero_prob_indices[i]
5332*da0073e9SAndroid Build Coastguard Worker                    for j in range(n_sample):
5333*da0073e9SAndroid Build Coastguard Worker                        sample_idx = sample_indices[i, j]
5334*da0073e9SAndroid Build Coastguard Worker                        if zero_prob_idx >= 0:
5335*da0073e9SAndroid Build Coastguard Worker                            self.assertNotEqual(sample_idx, zero_prob_idx,
5336*da0073e9SAndroid Build Coastguard Worker                                                msg="sampled an index with zero probability")
5337*da0073e9SAndroid Build Coastguard Worker                        self.assertNotIn(sample_idx, row_samples, "sampled an index twice")
5338*da0073e9SAndroid Build Coastguard Worker                        row_samples[sample_idx] = True
5339*da0073e9SAndroid Build Coastguard Worker
5340*da0073e9SAndroid Build Coastguard Worker            # vector
5341*da0073e9SAndroid Build Coastguard Worker            n_col = 4
5342*da0073e9SAndroid Build Coastguard Worker            prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1)
5343*da0073e9SAndroid Build Coastguard Worker            zero_prob_idx = 1  # index that shouldn't be sampled
5344*da0073e9SAndroid Build Coastguard Worker            prob_dist[zero_prob_idx] = 0
5345*da0073e9SAndroid Build Coastguard Worker            n_sample = 20
5346*da0073e9SAndroid Build Coastguard Worker            sample_indices = torch.multinomial(prob_dist, n_sample, True)
5347*da0073e9SAndroid Build Coastguard Worker            for sample_index in sample_indices:
5348*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(sample_index, zero_prob_idx, msg="sampled an index with zero probability")
5349*da0073e9SAndroid Build Coastguard Worker            s_dim = sample_indices.dim()
5350*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s_dim, 1, msg="wrong number of dimensions")
5351*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(prob_dist.dim(), 1, msg="wrong number of prob_dist dimensions")
5352*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sample_indices.size(0), n_sample, msg="wrong number of samples")
5353*da0073e9SAndroid Build Coastguard Worker
5354*da0073e9SAndroid Build Coastguard Worker        # CUDA misalignment issue (#46702)
5355*da0073e9SAndroid Build Coastguard Worker        n_row, n_col = 2, 3
5356*da0073e9SAndroid Build Coastguard Worker        prob_dist = make_prob_dist([n_row, n_col], True)
5357*da0073e9SAndroid Build Coastguard Worker        n_sample = 1
5358*da0073e9SAndroid Build Coastguard Worker        sample_indices = torch.multinomial(prob_dist, n_sample, True)
5359*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sample_indices.dim(), 2, msg="wrong number of dimensions")
5360*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sample_indices.size(1), n_sample, msg="wrong number of samples")
5361*da0073e9SAndroid Build Coastguard Worker
5362*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test distributions
5363*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
5364*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.half)
5365*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_deterministic(self, device, dtype):
5366*da0073e9SAndroid Build Coastguard Worker        gen = torch.Generator(device=device)
5367*da0073e9SAndroid Build Coastguard Worker
5368*da0073e9SAndroid Build Coastguard Worker        trials = 5
5369*da0073e9SAndroid Build Coastguard Worker        seed = 0
5370*da0073e9SAndroid Build Coastguard Worker        prob_dist = torch.rand(10000, 1000, device=device, dtype=dtype)
5371*da0073e9SAndroid Build Coastguard Worker        n_sample = 1
5372*da0073e9SAndroid Build Coastguard Worker
5373*da0073e9SAndroid Build Coastguard Worker        for i in range(trials):
5374*da0073e9SAndroid Build Coastguard Worker            gen.manual_seed(seed)
5375*da0073e9SAndroid Build Coastguard Worker            samples_1 = torch.multinomial(prob_dist, n_sample, True, generator=gen)
5376*da0073e9SAndroid Build Coastguard Worker
5377*da0073e9SAndroid Build Coastguard Worker            gen.manual_seed(seed)
5378*da0073e9SAndroid Build Coastguard Worker            samples_2 = torch.multinomial(prob_dist, n_sample, True, generator=gen)
5379*da0073e9SAndroid Build Coastguard Worker
5380*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(samples_1, samples_2)
5381*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(samples_1.dim(), 2, msg="wrong number of dimensions")
5382*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(samples_1.size(1), n_sample, msg="wrong number of samples")
5383*da0073e9SAndroid Build Coastguard Worker
5384*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test distributions
5385*da0073e9SAndroid Build Coastguard Worker    @slowTest
5386*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
5387*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_rng_state_advance(self, device, dtype):
5388*da0073e9SAndroid Build Coastguard Worker        corpus_size = 100000
5389*da0073e9SAndroid Build Coastguard Worker        freqs = torch.ones(corpus_size, dtype=torch.float, device=device)
5390*da0073e9SAndroid Build Coastguard Worker        n_sample = 100
5391*da0073e9SAndroid Build Coastguard Worker        samples1 = torch.multinomial(freqs, n_sample, replacement=True)
5392*da0073e9SAndroid Build Coastguard Worker        samples2 = torch.multinomial(freqs, n_sample, replacement=True)
5393*da0073e9SAndroid Build Coastguard Worker        samples = torch.cat([samples1, samples2])
5394*da0073e9SAndroid Build Coastguard Worker        # expect at most a couple of repeated elements across the 2 attempts;
5395*da0073e9SAndroid Build Coastguard Worker        # the probability of at least one element being repeated is surprisingly large, about 18%
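        # (birthday-problem estimate: with 2 * n_sample = 200 draws from a corpus of
        # 100000, P(any repeat) ~= 1 - exp(-200 * 199 / (2 * 100000)) ~= 0.18)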
5396*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(2 * n_sample - samples.unique().size(0), 2)
5397*da0073e9SAndroid Build Coastguard Worker        samples1 = torch.multinomial(freqs, n_sample, replacement=False)
5398*da0073e9SAndroid Build Coastguard Worker        samples2 = torch.multinomial(freqs, n_sample, replacement=False)
5399*da0073e9SAndroid Build Coastguard Worker        samples = torch.cat([samples1, samples2])
5400*da0073e9SAndroid Build Coastguard Worker        # expect no more than 1 repeated element generated in 2 attempts
5401*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(2 * n_sample - samples.unique().size(0), 1)
5402*da0073e9SAndroid Build Coastguard Worker
5403*da0073e9SAndroid Build Coastguard Worker    def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
5404*da0073e9SAndroid Build Coastguard Worker                                            memory_format, compare_data=True, default_is_preserve=False):
5405*da0073e9SAndroid Build Coastguard Worker
5406*da0073e9SAndroid Build Coastguard Worker        assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d
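        # This helper checks three things about `transformation_fn`:
        #   1. memory_format=torch.preserve_format keeps the channels-last layout
        #      (even for a non-dense input, outside of torchinductor),
        #   2. memory_format=torch.contiguous_format yields a contiguous result,
        #   3. with no explicit memory_format, the layout is preserved only when
        #      `default_is_preserve` is True; otherwise the result is contiguous.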
5407*da0073e9SAndroid Build Coastguard Worker
5408*da0073e9SAndroid Build Coastguard Worker        # xc is a channels last tensor
5409*da0073e9SAndroid Build Coastguard Worker        xc = input_generator_fn(device)
5410*da0073e9SAndroid Build Coastguard Worker        # Slicing below makes xc non memory-dense while it still looks like channels last;
5411*da0073e9SAndroid Build Coastguard Worker        # we don't preserve non-dense striding
5412*da0073e9SAndroid Build Coastguard Worker        if not TEST_WITH_TORCHINDUCTOR:
5413*da0073e9SAndroid Build Coastguard Worker            if memory_format == torch.channels_last:
5414*da0073e9SAndroid Build Coastguard Worker                xc = xc[..., ::2, ::2]
5415*da0073e9SAndroid Build Coastguard Worker            else:
5416*da0073e9SAndroid Build Coastguard Worker                xc = xc[..., ::2, ::2, ::2]
5417*da0073e9SAndroid Build Coastguard Worker
5418*da0073e9SAndroid Build Coastguard Worker        clone = transformation_fn(xc, memory_format=torch.preserve_format)
5419*da0073e9SAndroid Build Coastguard Worker
5420*da0073e9SAndroid Build Coastguard Worker
5421*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(clone.is_contiguous())
5422*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(clone.is_contiguous(memory_format=memory_format))
5423*da0073e9SAndroid Build Coastguard Worker        if not TEST_WITH_TORCHINDUCTOR:
5424*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(xc.is_contiguous())
5425*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(xc.is_contiguous(memory_format=memory_format))
5426*da0073e9SAndroid Build Coastguard Worker        if compare_data:
5427*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(xc, clone.to(xc))
5428*da0073e9SAndroid Build Coastguard Worker
5429*da0073e9SAndroid Build Coastguard Worker        xc = input_generator_fn(device)
5430*da0073e9SAndroid Build Coastguard Worker        clone = transformation_fn(xc, memory_format=torch.contiguous_format)
5431*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(clone.is_contiguous())
5432*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(clone.is_contiguous(memory_format=memory_format))
5433*da0073e9SAndroid Build Coastguard Worker        if compare_data:
5434*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(xc, clone.to(xc))
5435*da0073e9SAndroid Build Coastguard Worker
5436*da0073e9SAndroid Build Coastguard Worker        xc = input_generator_fn(device)
5437*da0073e9SAndroid Build Coastguard Worker        clone = transformation_fn(xc)
5438*da0073e9SAndroid Build Coastguard Worker
5439*da0073e9SAndroid Build Coastguard Worker        if default_is_preserve:
5440*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(clone.is_contiguous())
5441*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(clone.is_contiguous(memory_format=memory_format))
5442*da0073e9SAndroid Build Coastguard Worker        else:
5443*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(clone.is_contiguous())
5444*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(clone.is_contiguous(memory_format=memory_format))
5445*da0073e9SAndroid Build Coastguard Worker        if compare_data:
5446*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(xc, clone.to(xc))
5447*da0073e9SAndroid Build Coastguard Worker
5448*da0073e9SAndroid Build Coastguard Worker        # TODO: make the _like constructors copy the stride permutation instead of just the layout
5449*da0073e9SAndroid Build Coastguard Worker        if not TEST_WITH_TORCHINDUCTOR:
5450*da0073e9SAndroid Build Coastguard Worker            x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
5451*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
5452*da0073e9SAndroid Build Coastguard Worker                permutation = list(range(len(x.shape)))
5453*da0073e9SAndroid Build Coastguard Worker                random.shuffle(permutation)
5454*da0073e9SAndroid Build Coastguard Worker                x = x.permute(permutation)
5455*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())
5456*da0073e9SAndroid Build Coastguard Worker
5457*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_to(self, device):
5458*da0073e9SAndroid Build Coastguard Worker        def get_generator(memory_format, shape):
5459*da0073e9SAndroid Build Coastguard Worker            def input_generator_fn(device):
5460*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5461*da0073e9SAndroid Build Coastguard Worker            return input_generator_fn
5462*da0073e9SAndroid Build Coastguard Worker
5463*da0073e9SAndroid Build Coastguard Worker        def transformation_fn(tensor, **kwargs):
5464*da0073e9SAndroid Build Coastguard Worker            return tensor.to(dtype=torch.float64, **kwargs)
5465*da0073e9SAndroid Build Coastguard Worker
5466*da0073e9SAndroid Build Coastguard Worker        formats_shapes = (
5467*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last, (4, 3, 8, 8)),
5468*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5469*da0073e9SAndroid Build Coastguard Worker
5470*da0073e9SAndroid Build Coastguard Worker        for mf, shape in formats_shapes:
5471*da0073e9SAndroid Build Coastguard Worker            self._test_memory_format_transformations(
5472*da0073e9SAndroid Build Coastguard Worker                device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)
5473*da0073e9SAndroid Build Coastguard Worker
5474*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_type(self, device):
5475*da0073e9SAndroid Build Coastguard Worker        def get_generator(memory_format, shape):
5476*da0073e9SAndroid Build Coastguard Worker            def input_generator_fn(device):
5477*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5478*da0073e9SAndroid Build Coastguard Worker            return input_generator_fn
5479*da0073e9SAndroid Build Coastguard Worker
5480*da0073e9SAndroid Build Coastguard Worker        def transformation_fn(tensor, **kwargs):
5481*da0073e9SAndroid Build Coastguard Worker            return tensor.to(torch.float64, **kwargs)
5482*da0073e9SAndroid Build Coastguard Worker
5483*da0073e9SAndroid Build Coastguard Worker        formats_shapes = (
5484*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last, (4, 3, 8, 8)),
5485*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5486*da0073e9SAndroid Build Coastguard Worker
5487*da0073e9SAndroid Build Coastguard Worker        for mf, shape in formats_shapes:
5488*da0073e9SAndroid Build Coastguard Worker            self._test_memory_format_transformations(
5489*da0073e9SAndroid Build Coastguard Worker                device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)
5490*da0073e9SAndroid Build Coastguard Worker
5491*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_clone(self, device):
5492*da0073e9SAndroid Build Coastguard Worker        def get_generator(memory_format, shape):
5493*da0073e9SAndroid Build Coastguard Worker            def input_generator_fn(device):
5494*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5495*da0073e9SAndroid Build Coastguard Worker            return input_generator_fn
5496*da0073e9SAndroid Build Coastguard Worker
5497*da0073e9SAndroid Build Coastguard Worker        def transformation_fn(tensor, **kwargs):
5498*da0073e9SAndroid Build Coastguard Worker            return tensor.clone(**kwargs)
5499*da0073e9SAndroid Build Coastguard Worker
5500*da0073e9SAndroid Build Coastguard Worker        formats_shapes = (
5501*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last, (4, 3, 8, 8)),
5502*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5503*da0073e9SAndroid Build Coastguard Worker
5504*da0073e9SAndroid Build Coastguard Worker        for mf, shape in formats_shapes:
5505*da0073e9SAndroid Build Coastguard Worker            self._test_memory_format_transformations(
5506*da0073e9SAndroid Build Coastguard Worker                device, get_generator(mf, shape), transformation_fn, mf, True, default_is_preserve=True)
5507*da0073e9SAndroid Build Coastguard Worker
5508*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_factory_like_functions_preserve(self, device):
5509*da0073e9SAndroid Build Coastguard Worker        def get_generator(memory_format, shape):
5510*da0073e9SAndroid Build Coastguard Worker            def input_generator_fn(device):
5511*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5512*da0073e9SAndroid Build Coastguard Worker            return input_generator_fn
5513*da0073e9SAndroid Build Coastguard Worker
5514*da0073e9SAndroid Build Coastguard Worker        transformation_fns = [
5515*da0073e9SAndroid Build Coastguard Worker            lambda t, **kwargs: torch.zeros_like(t, **kwargs),
5516*da0073e9SAndroid Build Coastguard Worker            lambda t, **kwargs: torch.ones_like(t, **kwargs),
5517*da0073e9SAndroid Build Coastguard Worker            lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs),
5518*da0073e9SAndroid Build Coastguard Worker            lambda t, **kwargs: torch.randint_like(t, 100, **kwargs),
5519*da0073e9SAndroid Build Coastguard Worker            lambda t, **kwargs: torch.randn_like(t, **kwargs),
5520*da0073e9SAndroid Build Coastguard Worker            lambda t, **kwargs: torch.rand_like(t, **kwargs),
5521*da0073e9SAndroid Build Coastguard Worker            lambda t, **kwargs: torch.full_like(t, 7, **kwargs),
5522*da0073e9SAndroid Build Coastguard Worker            lambda t, **kwargs: torch.empty_like(t, **kwargs)]
5523*da0073e9SAndroid Build Coastguard Worker
5524*da0073e9SAndroid Build Coastguard Worker        formats_shapes = (
5525*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last, (4, 3, 8, 8)),
5526*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5527*da0073e9SAndroid Build Coastguard Worker
5528*da0073e9SAndroid Build Coastguard Worker        for mf, shape in formats_shapes:
5529*da0073e9SAndroid Build Coastguard Worker            for transformation_fn in transformation_fns:
5530*da0073e9SAndroid Build Coastguard Worker                self._test_memory_format_transformations(
5531*da0073e9SAndroid Build Coastguard Worker                    device, get_generator(mf, shape), transformation_fn, mf, compare_data=False, default_is_preserve=True)
5532*da0073e9SAndroid Build Coastguard Worker
5533*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_type_shortcuts(self, device):
5534*da0073e9SAndroid Build Coastguard Worker        def get_generator(memory_format, shape, dtype):
5535*da0073e9SAndroid Build Coastguard Worker            def input_generator_fn(device):
5536*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, device=device, dtype=dtype).clamp(0, 1) \
5537*da0073e9SAndroid Build Coastguard Worker                    .round().contiguous(memory_format=memory_format)
5538*da0073e9SAndroid Build Coastguard Worker            return input_generator_fn
5539*da0073e9SAndroid Build Coastguard Worker
5540*da0073e9SAndroid Build Coastguard Worker
5541*da0073e9SAndroid Build Coastguard Worker        def get_fn(fn_name):
5542*da0073e9SAndroid Build Coastguard Worker            def transformation_fn(tensor, **kwargs):
5543*da0073e9SAndroid Build Coastguard Worker                fn = getattr(tensor, fn_name)
5544*da0073e9SAndroid Build Coastguard Worker                return fn(**kwargs)
5545*da0073e9SAndroid Build Coastguard Worker            return transformation_fn
5546*da0073e9SAndroid Build Coastguard Worker
5547*da0073e9SAndroid Build Coastguard Worker        shortcuts = ['byte', 'char', 'double', 'bool', 'half', 'int', 'long', 'short']
5548*da0073e9SAndroid Build Coastguard Worker        if device == 'cpu':
5549*da0073e9SAndroid Build Coastguard Worker            shortcuts += ['bfloat16']
5550*da0073e9SAndroid Build Coastguard Worker
5551*da0073e9SAndroid Build Coastguard Worker        formats_shapes = (
5552*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last, (4, 3, 8, 8)),
5553*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5554*da0073e9SAndroid Build Coastguard Worker
5555*da0073e9SAndroid Build Coastguard Worker        for mf, shape in formats_shapes:
5556*da0073e9SAndroid Build Coastguard Worker            for fn_name in shortcuts:
5557*da0073e9SAndroid Build Coastguard Worker                self._test_memory_format_transformations(
5558*da0073e9SAndroid Build Coastguard Worker                    device, get_generator(mf, shape, torch.float32), get_fn(fn_name), mf, default_is_preserve=True)
5559*da0073e9SAndroid Build Coastguard Worker
5560*da0073e9SAndroid Build Coastguard Worker        # Test 'float' separately to avoid float->float no-op.
5561*da0073e9SAndroid Build Coastguard Worker        for mf, shape in formats_shapes:
5562*da0073e9SAndroid Build Coastguard Worker            self._test_memory_format_transformations(
5563*da0073e9SAndroid Build Coastguard Worker                device, get_generator(mf, shape, torch.float64), get_fn('float'), mf, default_is_preserve=True)
5564*da0073e9SAndroid Build Coastguard Worker
5565*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
5566*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_cpu_and_cuda_ops(self, device):
5567*da0073e9SAndroid Build Coastguard Worker        def get_generator(memory_format, shape):
5568*da0073e9SAndroid Build Coastguard Worker            def input_generator_fn(device):
5569*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
5570*da0073e9SAndroid Build Coastguard Worker            return input_generator_fn
5571*da0073e9SAndroid Build Coastguard Worker
5572*da0073e9SAndroid Build Coastguard Worker        def transformation_cpu_fn(tensor, **kwargs):
5573*da0073e9SAndroid Build Coastguard Worker            return tensor.cpu(**kwargs)
5574*da0073e9SAndroid Build Coastguard Worker
5575*da0073e9SAndroid Build Coastguard Worker        def transformation_cuda_fn(tensor, **kwargs):
5576*da0073e9SAndroid Build Coastguard Worker            return tensor.cuda(**kwargs)
5577*da0073e9SAndroid Build Coastguard Worker
5578*da0073e9SAndroid Build Coastguard Worker        formats_shapes = (
5579*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last, (4, 3, 8, 8)),
5580*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last_3d, (4, 3, 8, 8, 8)))
5581*da0073e9SAndroid Build Coastguard Worker
5582*da0073e9SAndroid Build Coastguard Worker        for mf, shape in formats_shapes:
5583*da0073e9SAndroid Build Coastguard Worker            self._test_memory_format_transformations(
5584*da0073e9SAndroid Build Coastguard Worker                'cuda', get_generator(mf, shape), transformation_cpu_fn, mf, default_is_preserve=True)
5585*da0073e9SAndroid Build Coastguard Worker            self._test_memory_format_transformations(
5586*da0073e9SAndroid Build Coastguard Worker                'cpu', get_generator(mf, shape), transformation_cuda_fn, mf, default_is_preserve=True)
5587*da0073e9SAndroid Build Coastguard Worker
5588*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_serialization
5589*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5590*da0073e9SAndroid Build Coastguard Worker    def test_pickle_gradscaler(self, device):
5591*da0073e9SAndroid Build Coastguard Worker        # This test should pass in 3 cases for cuda:
5592*da0073e9SAndroid Build Coastguard Worker        #  1. cuda is not available.
5593*da0073e9SAndroid Build Coastguard Worker        #  2. cuda is available but device is not cuda.
5594*da0073e9SAndroid Build Coastguard Worker        #  3. cuda is available and device is cuda.
5595*da0073e9SAndroid Build Coastguard Worker        # In case 1, a and b disable themselves on construction and shouldn't try to pickle workhorse attributes.
5596*da0073e9SAndroid Build Coastguard Worker        # In case 2, a and b are enabled.  Workhorse attributes participate in pickling, but none are lazy-inited
5597*da0073e9SAndroid Build Coastguard Worker        # to cuda Tensors, because I don't want to do cuda things if device is not cuda.
5598*da0073e9SAndroid Build Coastguard Worker        # In case 3, a and b are enabled and we may also try lazy-initing _scale to a cuda tensor.
5599*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
5600*da0073e9SAndroid Build Coastguard Worker        try_lazy_inits = (True, False)
5601*da0073e9SAndroid Build Coastguard Worker        GradScaler = partial(torch.GradScaler, device=device.type)
5602*da0073e9SAndroid Build Coastguard Worker        for lazy_init_scale in try_lazy_inits:
5603*da0073e9SAndroid Build Coastguard Worker            a = GradScaler(init_scale=3., growth_factor=4., backoff_factor=.5, growth_interval=2)
5604*da0073e9SAndroid Build Coastguard Worker            if device.type == "cuda":
5605*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(not a.is_enabled() if torch.cuda.amp.common.amp_definitely_not_available() else a.is_enabled())
5606*da0073e9SAndroid Build Coastguard Worker            else:
5607*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(a.is_enabled())
5608*da0073e9SAndroid Build Coastguard Worker            if lazy_init_scale:
5609*da0073e9SAndroid Build Coastguard Worker                # Dummy a.scale() call lazy-inits a._scale Tensor.
5610*da0073e9SAndroid Build Coastguard Worker                a.scale(torch.tensor([4.0], dtype=torch.float32, device=device))
5611*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(a._scale.device.type == device.type)
5612*da0073e9SAndroid Build Coastguard Worker            # The following three lines should work whether or not cuda is available.
5613*da0073e9SAndroid Build Coastguard Worker            serialized = pickle.dumps(a)
5614*da0073e9SAndroid Build Coastguard Worker            b = pickle.loads(serialized)
5615*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b.is_enabled(), a.is_enabled())
5616*da0073e9SAndroid Build Coastguard Worker            if a.is_enabled():
5617*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.get_scale(), 3.)
5618*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.get_growth_factor(), 4.)
5619*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.get_backoff_factor(), .5)
5620*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.get_growth_interval(), 2)
5621*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b._init_growth_tracker, 0)
5622*da0073e9SAndroid Build Coastguard Worker                # supplies a dummy key to test the defaultdict's default_factory
5623*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b._per_optimizer_states["fdsa"],
5624*da0073e9SAndroid Build Coastguard Worker                                 torch.amp.grad_scaler._refresh_per_optimizer_state())
5625*da0073e9SAndroid Build Coastguard Worker                if lazy_init_scale:
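                    # init_scale is 3.0, so scaling a tensor filled with 4.0 yields 12.0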
5626*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.scale(torch.tensor([4.0], dtype=torch.float32, device=device)), 12.0)
5627*da0073e9SAndroid Build Coastguard Worker
5628*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test distributions
5629*da0073e9SAndroid Build Coastguard Worker    def _test_multinomial_empty(self, device, replacement, num_samples):
5630*da0073e9SAndroid Build Coastguard Worker        probs = torch.ones(0, 3, device=device)
5631*da0073e9SAndroid Build Coastguard Worker        expected = torch.empty(0, num_samples, dtype=torch.int64)
5632*da0073e9SAndroid Build Coastguard Worker        out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
5633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, expected)
5634*da0073e9SAndroid Build Coastguard Worker
5635*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test distributions
5636*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_empty_w_replacement(self, device):
5637*da0073e9SAndroid Build Coastguard Worker        self._test_multinomial_empty(device, True, 1)
5638*da0073e9SAndroid Build Coastguard Worker        self._test_multinomial_empty(device, True, 2)
5639*da0073e9SAndroid Build Coastguard Worker
5640*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test distributions
5641*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_empty_wo_replacement(self, device):
5642*da0073e9SAndroid Build Coastguard Worker        self._test_multinomial_empty(device, False, 1)
5643*da0073e9SAndroid Build Coastguard Worker        self._test_multinomial_empty(device, False, 2)
5644*da0073e9SAndroid Build Coastguard Worker
5645*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5646*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
5647*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_unscale(self, device, dtype):
5648*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
5649*da0073e9SAndroid Build Coastguard Worker        device0 = "cuda:0" if device.type == "cuda" else "cpu"
5650*da0073e9SAndroid Build Coastguard Worker        inv_scale = torch.full((1,), 0.25, dtype=torch.float, device=device0)
5651*da0073e9SAndroid Build Coastguard Worker        found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device0)
5652*da0073e9SAndroid Build Coastguard Worker
5653*da0073e9SAndroid Build Coastguard Worker        size = 20
5654*da0073e9SAndroid Build Coastguard Worker        g = torch.full((size, size), 4.0, dtype=dtype, device=device0)
5655*da0073e9SAndroid Build Coastguard Worker        ginf = g.clone()
5656*da0073e9SAndroid Build Coastguard Worker        ginf[2, 2] = float('inf')
5657*da0073e9SAndroid Build Coastguard Worker        gnan = g.clone()
5658*da0073e9SAndroid Build Coastguard Worker        gnan[2, 2] = float('nan')
5659*da0073e9SAndroid Build Coastguard Worker
5660*da0073e9SAndroid Build Coastguard Worker        # Tries selected combinations of
5661*da0073e9SAndroid Build Coastguard Worker        #  - contiguous grads
5662*da0073e9SAndroid Build Coastguard Worker        #  - g.clone().t(), which is not contiguous but still non-overlapping and dense
5663*da0073e9SAndroid Build Coastguard Worker        #  - variants of g.clone()[:, :5], which are not non-overlapping-and-dense
5664*da0073e9SAndroid Build Coastguard Worker        # Non-overlapping and dense grads route into a multi-tensor-apply kernel;
5665*da0073e9SAndroid Build Coastguard Worker        # others use a fallback per-tensor kernel, so we should try both.
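        # Every element of g is 4.0 and inv_scale is 0.25, so a successful unscale
        # should leave each grad filled with 1.0; found_inf is expected to be 1.0
        # whenever any grad contains an inf or nan.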
5666*da0073e9SAndroid Build Coastguard Worker        cases = (
5667*da0073e9SAndroid Build Coastguard Worker            ([g.clone(), g.clone()], False),
5668*da0073e9SAndroid Build Coastguard Worker            ([g.clone(), g.clone().t()], False),
5669*da0073e9SAndroid Build Coastguard Worker            ([g.clone(), g.clone()[:, :5]], False),
5670*da0073e9SAndroid Build Coastguard Worker            ([g.clone()[:, :5], g.clone()[:, :5]], False),
5671*da0073e9SAndroid Build Coastguard Worker            ([g.clone(), ginf.clone()], True),
5672*da0073e9SAndroid Build Coastguard Worker            ([g.clone(), gnan.clone()], True),
5673*da0073e9SAndroid Build Coastguard Worker            ([g.clone(), ginf.clone()[:, :5]], True),
5674*da0073e9SAndroid Build Coastguard Worker            ([g.clone(), gnan.clone()[:, :5]], True),
5675*da0073e9SAndroid Build Coastguard Worker            ([ginf.clone(), g.clone()[:, :5]], True),
5676*da0073e9SAndroid Build Coastguard Worker            ([ginf.clone()[:, :5], g.clone()[:, :5]], True),
5677*da0073e9SAndroid Build Coastguard Worker        )
5678*da0073e9SAndroid Build Coastguard Worker
5679*da0073e9SAndroid Build Coastguard Worker        for grads, has_inf in cases:
5680*da0073e9SAndroid Build Coastguard Worker            found_inf.zero_()
5681*da0073e9SAndroid Build Coastguard Worker            torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
5682*da0073e9SAndroid Build Coastguard Worker            if has_inf:
5683*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(found_inf, 1.0)
5684*da0073e9SAndroid Build Coastguard Worker            else:
5685*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(found_inf, 0.0)
5686*da0073e9SAndroid Build Coastguard Worker                for grad in grads:
5687*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
5688*da0073e9SAndroid Build Coastguard Worker
5689*da0073e9SAndroid Build Coastguard Worker        # When passing lists with mismatched dtypes to a raw
5690*da0073e9SAndroid Build Coastguard Worker        # _amp_foreach_non_finite_check_and_unscale_ call on CUDA,
5691*da0073e9SAndroid Build Coastguard Worker        # it's expected to fall back to the single-tensor TensorIterator kernel.
5692*da0073e9SAndroid Build Coastguard Worker        grads = [g.clone(), g.to(dtype=torch.float16)]
5693*da0073e9SAndroid Build Coastguard Worker        torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
5694*da0073e9SAndroid Build Coastguard Worker        for grad in grads:
5695*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
5696*da0073e9SAndroid Build Coastguard Worker
5697*da0073e9SAndroid Build Coastguard Worker        # Passing lists with mismatched devices to a raw
5698*da0073e9SAndroid Build Coastguard Worker        # _amp_foreach_non_finite_check_and_unscale_ call should raise errors.
5699*da0073e9SAndroid Build Coastguard Worker        if device.type == "cuda" and TEST_MULTIGPU:
5700*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Expected all tensors to be on the same device"):
5701*da0073e9SAndroid Build Coastguard Worker                torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(device="cuda:1")],
5702*da0073e9SAndroid Build Coastguard Worker                                                                 found_inf,
5703*da0073e9SAndroid Build Coastguard Worker                                                                 inv_scale)
5704*da0073e9SAndroid Build Coastguard Worker
5705*da0073e9SAndroid Build Coastguard Worker        # Creates a list of grads with mismatched dtypes and devices, to ensure
5706*da0073e9SAndroid Build Coastguard Worker        # scaler._unscale_grads_ organizes grads by dtype and device before calling
5707*da0073e9SAndroid Build Coastguard Worker        # _amp_foreach_non_finite_check_and_unscale_ on each set.
5708*da0073e9SAndroid Build Coastguard Worker        # If inject_inf >= 0, writes an inf into one grad for _unscale_grads_ to find.
5709*da0073e9SAndroid Build Coastguard Worker        def perfect_storm_grads(inject_inf):
5710*da0073e9SAndroid Build Coastguard Worker            grads = [g.clone(), g.clone()[:, :5], g.to(dtype=torch.float16), g.to(dtype=torch.float16)]
5711*da0073e9SAndroid Build Coastguard Worker            if device.type == "cuda" and TEST_MULTIGPU:
5712*da0073e9SAndroid Build Coastguard Worker                grads += [g.to(device="cuda:1"),
5713*da0073e9SAndroid Build Coastguard Worker                          g.to(device="cuda:1")[:, :5],
5714*da0073e9SAndroid Build Coastguard Worker                          g.to(device="cuda:1", dtype=torch.float16),
5715*da0073e9SAndroid Build Coastguard Worker                          g.to(device="cuda:1", dtype=torch.float16)]
5716*da0073e9SAndroid Build Coastguard Worker            if inject_inf >= 0:
5717*da0073e9SAndroid Build Coastguard Worker                grads[inject_inf][2, 2] = float('inf')
5718*da0073e9SAndroid Build Coastguard Worker            return grads
5719*da0073e9SAndroid Build Coastguard Worker
5720*da0073e9SAndroid Build Coastguard Worker        GradScaler = partial(torch.GradScaler, device=device.type)
5721*da0073e9SAndroid Build Coastguard Worker        scaler = GradScaler()
5722*da0073e9SAndroid Build Coastguard Worker        dummy_params = [torch.empty_like(g) for g in perfect_storm_grads(-1)]
5723*da0073e9SAndroid Build Coastguard Worker        dummy_opt = torch.optim.SGD(dummy_params, lr=1.)
5724*da0073e9SAndroid Build Coastguard Worker
5725*da0073e9SAndroid Build Coastguard Worker        # Ensures the inf/nan checking can find an inf injected onto any grad in the perfect storm.
5726*da0073e9SAndroid Build Coastguard Worker        for inject_inf in range(-1, len(dummy_params)):
5727*da0073e9SAndroid Build Coastguard Worker            found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device0)
5728*da0073e9SAndroid Build Coastguard Worker            grads = perfect_storm_grads(inject_inf)
5729*da0073e9SAndroid Build Coastguard Worker            for i, p in enumerate(dummy_params):
5730*da0073e9SAndroid Build Coastguard Worker                p.grad = grads[i]
5731*da0073e9SAndroid Build Coastguard Worker            found_inf_per_device = scaler._unscale_grads_(dummy_opt, inv_scale, found_inf, True)
5732*da0073e9SAndroid Build Coastguard Worker            if inject_inf < 0:
5733*da0073e9SAndroid Build Coastguard Worker                # No inf was injected, ensures unscaling worked normally.
5734*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 0)
5735*da0073e9SAndroid Build Coastguard Worker                for grad in grads:
5736*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
5737*da0073e9SAndroid Build Coastguard Worker            else:
5738*da0073e9SAndroid Build Coastguard Worker                # inf was injected, ensures inf was found.
5739*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 1)
5740*da0073e9SAndroid Build Coastguard Worker
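    # Illustrative sketch (never invoked by the suite) of the kind of bucketing the
    # "perfect storm" above is designed to exercise: grads grouped by (device, dtype),
    # one foreach unscale call per bucket and one found_inf per device.  The real
    # grouping lives inside GradScaler._unscale_grads_; this helper and its name are
    # hypothetical.
    def _sketch_group_grads_by_device_and_dtype(self, grads, inv_scale):
        from collections import defaultdict

        buckets = defaultdict(list)
        for grad in grads:
            buckets[(grad.device, grad.dtype)].append(grad)
        per_device_found_inf = {}
        for (dev, _), bucket in buckets.items():
            found_inf = per_device_found_inf.setdefault(
                dev, torch.zeros((1,), dtype=torch.float32, device=dev))
            torch._amp_foreach_non_finite_check_and_unscale_(bucket, found_inf, inv_scale.to(dev))
        return per_device_found_inf
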
5741*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5742*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
5743*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_update_scale(self, device, dtype):
5744*da0073e9SAndroid Build Coastguard Worker        growth = 2.0
5745*da0073e9SAndroid Build Coastguard Worker        backoff = 0.25
5746*da0073e9SAndroid Build Coastguard Worker        growth_interval = 2
5747*da0073e9SAndroid Build Coastguard Worker        scale = torch.full((1,), 4.0, dtype=dtype, device=device)
5748*da0073e9SAndroid Build Coastguard Worker        growth_tracker = torch.full((1,), 0.0, dtype=torch.int32, device=device)
5749*da0073e9SAndroid Build Coastguard Worker        found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device)
5750*da0073e9SAndroid Build Coastguard Worker
5751*da0073e9SAndroid Build Coastguard Worker        # Simulates 2 consecutive unskipped iterations
5752*da0073e9SAndroid Build Coastguard Worker        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
5753*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(growth_tracker, 1)
5754*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale, 4.0)
5755*da0073e9SAndroid Build Coastguard Worker        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
5756*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(growth_tracker, 0)
5757*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale, 8.0)
5758*da0073e9SAndroid Build Coastguard Worker
5759*da0073e9SAndroid Build Coastguard Worker        # Simulates a skipped iteration
5760*da0073e9SAndroid Build Coastguard Worker        found_inf.fill_(1.0)
5761*da0073e9SAndroid Build Coastguard Worker        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
5762*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(growth_tracker, 0)
5763*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale, 2.0)
5764*da0073e9SAndroid Build Coastguard Worker
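    # A minimal pure-Python restatement (illustrative only, never called by the
    # suite) of the growth/backoff rule that the assertions above exercise via
    # torch._amp_update_scale_.  The helper name is hypothetical.
    def _sketch_reference_amp_update_scale(self, scale, growth_tracker, found_inf,
                                           growth, backoff, growth_interval):
        # Any inf/nan this iteration: back off and reset the tracker.  Otherwise,
        # grow the scale once every `growth_interval` consecutive clean iterations.
        if found_inf:
            return scale * backoff, 0
        growth_tracker += 1
        if growth_tracker == growth_interval:
            return scale * growth, 0
        return scale, growth_tracker
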
5765*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Failed running call_function for sparse_coo_tensor. See https://github.com/pytorch/pytorch/issues/118856")
5766*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5767*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
5768*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_unscale_sparse(self, device, dtype):
5769*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
5770*da0073e9SAndroid Build Coastguard Worker        scaler = torch.GradScaler(device=device.type)
5771*da0073e9SAndroid Build Coastguard Worker
5772*da0073e9SAndroid Build Coastguard Worker        inv_scale = torch.full((1,), 0.25, dtype=dtype, device=device)
5773*da0073e9SAndroid Build Coastguard Worker        found_inf = torch.empty((1,), dtype=dtype, device=device)
5774*da0073e9SAndroid Build Coastguard Worker        cur = found_inf.device
5775*da0073e9SAndroid Build Coastguard Worker
5776*da0073e9SAndroid Build Coastguard Worker        i = torch.tensor([[0, 1, 1],
5777*da0073e9SAndroid Build Coastguard Worker                          [2, 0, 2]], device=device, dtype=torch.int64)
5778*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([16., 32., 64.], device=device, dtype=torch.float)
5779*da0073e9SAndroid Build Coastguard Worker        s = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype)
5780*da0073e9SAndroid Build Coastguard Worker
5781*da0073e9SAndroid Build Coastguard Worker        p = s.clone()
5782*da0073e9SAndroid Build Coastguard Worker        assert p.is_sparse
5783*da0073e9SAndroid Build Coastguard Worker        opt = torch.optim.SGD([p], lr=1.)
5784*da0073e9SAndroid Build Coastguard Worker
5785*da0073e9SAndroid Build Coastguard Worker        p.grad = s.clone()
5786*da0073e9SAndroid Build Coastguard Worker        found_inf.zero_()
5787*da0073e9SAndroid Build Coastguard Worker        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
5788*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(found_inf, 0.0)
5789*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p.grad.to_dense(), (s / 4).to_dense())
5790*da0073e9SAndroid Build Coastguard Worker
5791*da0073e9SAndroid Build Coastguard Worker        v = torch.FloatTensor([16., 32., float('inf')])
5792*da0073e9SAndroid Build Coastguard Worker        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype)
5793*da0073e9SAndroid Build Coastguard Worker        found_inf.zero_()
5794*da0073e9SAndroid Build Coastguard Worker        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
5795*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(found_inf, 1.0)
5796*da0073e9SAndroid Build Coastguard Worker
5797*da0073e9SAndroid Build Coastguard Worker        v = torch.FloatTensor([16., 32., float('nan')])
5798*da0073e9SAndroid Build Coastguard Worker        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype)
5799*da0073e9SAndroid Build Coastguard Worker        found_inf.zero_()
5800*da0073e9SAndroid Build Coastguard Worker        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur]
5801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(found_inf, 1.0)
5802*da0073e9SAndroid Build Coastguard Worker
5803*da0073e9SAndroid Build Coastguard Worker        p = s.clone().half()
5804*da0073e9SAndroid Build Coastguard Worker        assert p.is_sparse
5805*da0073e9SAndroid Build Coastguard Worker        opt = torch.optim.SGD([p], lr=1.)
5806*da0073e9SAndroid Build Coastguard Worker
5807*da0073e9SAndroid Build Coastguard Worker        p.grad = s.clone().half()
5808*da0073e9SAndroid Build Coastguard Worker        found_inf.zero_()
5809*da0073e9SAndroid Build Coastguard Worker        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur]
5810*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(found_inf, 0.0)
5811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p.grad.to_dense(), (s.half() / 4).to_dense())
5812*da0073e9SAndroid Build Coastguard Worker
5813*da0073e9SAndroid Build Coastguard Worker        # Creates fp16 sparse tensor with duplicated indices (uncoalesced).  The uncoalesced representation
5814*da0073e9SAndroid Build Coastguard Worker        # does not overflow in fp16, but the coalesced representation would, because 64000 + 64000 > fp16 max.
5815*da0073e9SAndroid Build Coastguard Worker        # _amp_non_finite_check_and_unscale_ should report an overflow here.
5816*da0073e9SAndroid Build Coastguard Worker        i = torch.LongTensor([[0, 1, 0],
5817*da0073e9SAndroid Build Coastguard Worker                              [2, 0, 2]])
5818*da0073e9SAndroid Build Coastguard Worker        v = torch.FloatTensor([64000., 32., 64000.])
5819*da0073e9SAndroid Build Coastguard Worker        p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=torch.float16)
5820*da0073e9SAndroid Build Coastguard Worker        found_inf.zero_()
5821*da0073e9SAndroid Build Coastguard Worker        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur]
5822*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(found_inf, 1.0)
5823*da0073e9SAndroid Build Coastguard Worker
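    # Illustrative sketch (not run by the suite) of the arithmetic behind the
    # uncoalesced-fp16 case above: fp16 tops out at 65504, so summing the duplicated
    # 64000-valued entries while coalescing overflows to inf, which the unscale path
    # is then expected to flag.  Assumes coalesce() accumulates in the values' dtype.
    def _sketch_fp16_coalesce_overflow(self):
        half_max = torch.finfo(torch.float16).max  # 65504.0
        a = torch.tensor(64000.0, dtype=torch.float16)
        assert 64000.0 + 64000.0 > half_max
        assert torch.isinf(a + a)
        i = torch.tensor([[0, 0], [2, 2]])
        v = torch.tensor([64000.0, 64000.0], dtype=torch.float16)
        duplicated = torch.sparse_coo_tensor(i, v, (2, 3))
        assert torch.isinf(duplicated.coalesce().values()).any()
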
5824*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5825*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_state_dict(self, device):
5826*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
5827*da0073e9SAndroid Build Coastguard Worker        GradScaler = partial(torch.GradScaler, device=device.type)
5828*da0073e9SAndroid Build Coastguard Worker        for lazy_init_scale in True, False:
5829*da0073e9SAndroid Build Coastguard Worker            s0 = GradScaler(init_scale=3., growth_factor=4., backoff_factor=.5, growth_interval=2)
5830*da0073e9SAndroid Build Coastguard Worker            s1 = GradScaler(init_scale=6., growth_factor=7., backoff_factor=.8, growth_interval=1)
5831*da0073e9SAndroid Build Coastguard Worker
5832*da0073e9SAndroid Build Coastguard Worker            # sets a random value for load_state_dict to overwrite
5833*da0073e9SAndroid Build Coastguard Worker            s1._init_growth_tracker = 7
5834*da0073e9SAndroid Build Coastguard Worker
5835*da0073e9SAndroid Build Coastguard Worker            if lazy_init_scale:
5836*da0073e9SAndroid Build Coastguard Worker                # Dummy scale() call to ensure the scale tensor is lazily initialized.
5837*da0073e9SAndroid Build Coastguard Worker                s1.scale(torch.full((1,), 4.0, dtype=torch.float32, device=device))
5838*da0073e9SAndroid Build Coastguard Worker                if "cuda" == device.type:
5839*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(isinstance(s1._scale, torch.cuda.FloatTensor))
5840*da0073e9SAndroid Build Coastguard Worker                else:
5841*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(isinstance(s1._scale, torch.FloatTensor))
5842*da0073e9SAndroid Build Coastguard Worker
5843*da0073e9SAndroid Build Coastguard Worker            s1.load_state_dict(s0.state_dict())
5844*da0073e9SAndroid Build Coastguard Worker
5845*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s1.get_scale(), 3.)
5846*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s1.get_growth_factor(), 4.)
5847*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s1.get_backoff_factor(), .5)
5848*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s1.get_growth_interval(), 2)
5849*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s1._init_growth_tracker, 0)
5850*da0073e9SAndroid Build Coastguard Worker
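    # Hedged sketch (not invoked by the suite) of the usual checkpoint round-trip
    # that the state_dict()/load_state_dict() calls above support; the helper and
    # its arguments are illustrative only.
    def _sketch_scaler_checkpoint_roundtrip(self, model, optimizer, scaler, path):
        torch.save({"model": model.state_dict(),
                    "optim": optimizer.state_dict(),
                    "scaler": scaler.state_dict()}, path)
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optim"])
        scaler.load_state_dict(checkpoint["scaler"])
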
5851*da0073e9SAndroid Build Coastguard Worker    # _run_scaling_case generalizes some single-optimizer test logic to avoid too much copy-pasting below.
5852*da0073e9SAndroid Build Coastguard Worker    def _run_scaling_case(self, device, run, unskipped, skipped, atol=1e-7, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
5853*da0073e9SAndroid Build Coastguard Worker        # Ensure scaling can be disabled without changing user control flow.
5854*da0073e9SAndroid Build Coastguard Worker        for enabled in True, False:
5855*da0073e9SAndroid Build Coastguard Worker            (
5856*da0073e9SAndroid Build Coastguard Worker                mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter,
5857*da0073e9SAndroid Build Coastguard Worker            ) = _create_scaling_case(device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs)
5858*da0073e9SAndroid Build Coastguard Worker
5859*da0073e9SAndroid Build Coastguard Worker            # For functionality, test with a modest initial scale, and an unrealistically-large growth factor
5860*da0073e9SAndroid Build Coastguard Worker            # so any potential errors with the growth factor handling will be magnified.
5861*da0073e9SAndroid Build Coastguard Worker            GradScaler = partial(torch.GradScaler, device=device)
5862*da0073e9SAndroid Build Coastguard Worker            scaler = GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled, growth_interval=1)
5863*da0073e9SAndroid Build Coastguard Worker
5864*da0073e9SAndroid Build Coastguard Worker            _ = run(device, data, mod_control, opt_control, scaler, loss_fn, skip_iter, False)
5865*da0073e9SAndroid Build Coastguard Worker            ret = run(device, data, mod_scaling, opt_scaling, scaler, loss_fn, skip_iter, True)
5866*da0073e9SAndroid Build Coastguard Worker
5867*da0073e9SAndroid Build Coastguard Worker            # Allows run() to optionally return a different scaler instance.
5868*da0073e9SAndroid Build Coastguard Worker            scaler = ret if ret else scaler
5869*da0073e9SAndroid Build Coastguard Worker
5870*da0073e9SAndroid Build Coastguard Worker            # If scaling was enabled, the scale factor should have been multiplied by the growth factor
5871*da0073e9SAndroid Build Coastguard Worker            # "unskipped" (i.e. len(data) - skipped) times and the backoff factor "skipped" times.
5872*da0073e9SAndroid Build Coastguard Worker            if enabled:
5873*da0073e9SAndroid Build Coastguard Worker                net_growth = scaler.get_growth_factor()**unskipped if unskipped > 0 else 1.0
5874*da0073e9SAndroid Build Coastguard Worker                net_backoff = scaler.get_backoff_factor()**skipped if skipped > 0 else 1.0
5875*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(scaler.get_scale() == (128. * net_growth * net_backoff))
5876*da0073e9SAndroid Build Coastguard Worker            else:
5877*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(scaler.get_scale() == 1.0)
5878*da0073e9SAndroid Build Coastguard Worker
5879*da0073e9SAndroid Build Coastguard Worker            for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
5880*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c.grad, s.grad, atol=atol, rtol=1e-05)
5881*da0073e9SAndroid Build Coastguard Worker
5882*da0073e9SAndroid Build Coastguard Worker                c_state, s_state = opt_control.state[c], opt_scaling.state[s]
5883*da0073e9SAndroid Build Coastguard Worker                for k in c_state:
5884*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(c_state[k], s_state[k], atol=atol, rtol=1e-05, msg=k)
5885*da0073e9SAndroid Build Coastguard Worker
5886*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c, s, atol=atol, rtol=1e-05)
5887*da0073e9SAndroid Build Coastguard Worker
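    # Closed form of the check above, spelled out (illustrative helper, not used by
    # the suite): with the values the autocast test below passes in, this is
    # 128. * 2.0**3 * 0.5**1 == 512.0 after 3 grown iterations and 1 skipped one.
    def _sketch_expected_scale(self, init_scale, growth, backoff, unskipped, skipped):
        return init_scale * growth ** unskipped * backoff ** skipped
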
5888*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5889*da0073e9SAndroid Build Coastguard Worker    @parametrize("foreach, fused", [(None, None), (True, None), (None, True)])
5890*da0073e9SAndroid Build Coastguard Worker    @optims(
5891*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]],
5892*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32]
5893*da0073e9SAndroid Build Coastguard Worker    )
5894*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_autocast(self, device, dtype, optim_info, foreach, fused):
5895*da0073e9SAndroid Build Coastguard Worker        try_pickle = False
5896*da0073e9SAndroid Build Coastguard Worker
5897*da0073e9SAndroid Build Coastguard Worker        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
5898*da0073e9SAndroid Build Coastguard Worker            for i, (input, target) in enumerate(data):
5899*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
5900*da0073e9SAndroid Build Coastguard Worker                with torch.autocast(device_type=device, dtype=torch.half, enabled=try_scaling_api):
5901*da0073e9SAndroid Build Coastguard Worker                    output = model(input)
5902*da0073e9SAndroid Build Coastguard Worker                    loss = loss_fn(output, target)
5903*da0073e9SAndroid Build Coastguard Worker                if try_scaling_api:
5904*da0073e9SAndroid Build Coastguard Worker                    scaler.scale(loss).backward()
5905*da0073e9SAndroid Build Coastguard Worker                    if i == skip_iter and scaler.is_enabled():
5906*da0073e9SAndroid Build Coastguard Worker                        with torch.no_grad():
5907*da0073e9SAndroid Build Coastguard Worker                            model[1].weight.grad.fill_(float('inf'))
5908*da0073e9SAndroid Build Coastguard Worker                    scaler.step(optimizer)
5909*da0073e9SAndroid Build Coastguard Worker                    scaler.update()
5910*da0073e9SAndroid Build Coastguard Worker                    if try_pickle:
5911*da0073e9SAndroid Build Coastguard Worker                        scaler = pickle.loads(pickle.dumps(scaler))
5912*da0073e9SAndroid Build Coastguard Worker                else:
5913*da0073e9SAndroid Build Coastguard Worker                    loss.backward()
5914*da0073e9SAndroid Build Coastguard Worker                    if (not scaler.is_enabled()) or (i != skip_iter):
5915*da0073e9SAndroid Build Coastguard Worker                        optimizer.step()
5916*da0073e9SAndroid Build Coastguard Worker            return scaler
5917*da0073e9SAndroid Build Coastguard Worker
5918*da0073e9SAndroid Build Coastguard Worker        optimizer_ctor = optim_info.optim_cls
5919*da0073e9SAndroid Build Coastguard Worker
5920*da0073e9SAndroid Build Coastguard Worker        # Compares no scaling + no autocasting against scaling + autocasting.
5921*da0073e9SAndroid Build Coastguard Worker        # NOTE(mkozuki): With current way of testing, `torch.optim.Adam` is failing in spite of `foreach` and `fused`.
5922*da0073e9SAndroid Build Coastguard Worker        #   Giving some flexibility to this test might help.
5923*da0073e9SAndroid Build Coastguard Worker        context = contextlib.nullcontext
5924*da0073e9SAndroid Build Coastguard Worker        if optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
5926*da0073e9SAndroid Build Coastguard Worker            context = partial(self.assertRaises, AssertionError)
5927*da0073e9SAndroid Build Coastguard Worker        with context():
5928*da0073e9SAndroid Build Coastguard Worker            # sets atol=1e-3 because we're comparing pure fp32 arithmetic vs a mixture of fp16 and fp32
5929*da0073e9SAndroid Build Coastguard Worker            self._run_scaling_case(
5930*da0073e9SAndroid Build Coastguard Worker                device, run, unskipped=3, skipped=1, atol=1e-3,
5931*da0073e9SAndroid Build Coastguard Worker                optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused},
5932*da0073e9SAndroid Build Coastguard Worker            )
5933*da0073e9SAndroid Build Coastguard Worker            # this will be picked up by try_pickle within run():
5934*da0073e9SAndroid Build Coastguard Worker            try_pickle = True
5935*da0073e9SAndroid Build Coastguard Worker            self._run_scaling_case(
5936*da0073e9SAndroid Build Coastguard Worker                device, run, unskipped=3, skipped=1, atol=1e-3,
5937*da0073e9SAndroid Build Coastguard Worker                optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused},
5938*da0073e9SAndroid Build Coastguard Worker            )
5939*da0073e9SAndroid Build Coastguard Worker
5940*da0073e9SAndroid Build Coastguard Worker    # Make sure the parameters become nonsense when the scaled gradients are finite
5941*da0073e9SAndroid Build Coastguard Worker    # after `GradScaler.unscale_` but get invalidated (set to inf/nan) before `optimizer.step`.
5942*da0073e9SAndroid Build Coastguard Worker
5943*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5944*da0073e9SAndroid Build Coastguard Worker    @optims(
5945*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]],
5946*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32]
5947*da0073e9SAndroid Build Coastguard Worker    )
5948*da0073e9SAndroid Build Coastguard Worker    def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info):
5949*da0073e9SAndroid Build Coastguard Worker        optimizer_ctor = optim_info.optim_cls
5950*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
5951*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",))
5952*da0073e9SAndroid Build Coastguard Worker
5953*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
5954*da0073e9SAndroid Build Coastguard Worker            model, _, optimizer, _, data, loss_fn, _ = _create_scaling_case(
5955*da0073e9SAndroid Build Coastguard Worker                device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optim_input.kwargs,
5956*da0073e9SAndroid Build Coastguard Worker            )
5957*da0073e9SAndroid Build Coastguard Worker            scaler = torch.GradScaler(device=device, init_scale=128.0)
5958*da0073e9SAndroid Build Coastguard Worker
5959*da0073e9SAndroid Build Coastguard Worker            for input, target in data:
5960*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
5961*da0073e9SAndroid Build Coastguard Worker                with torch.autocast(device_type=device, dtype=torch.half):
5962*da0073e9SAndroid Build Coastguard Worker                    output = model(input)
5963*da0073e9SAndroid Build Coastguard Worker                    loss = loss_fn(output, target)
5964*da0073e9SAndroid Build Coastguard Worker                scaler.scale(loss).backward()
5965*da0073e9SAndroid Build Coastguard Worker                scaler.unscale_(optimizer)
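                # found_inf was recorded here, while the grads were still finite, so
                # the scaler will not skip the step() below even though the grads are
                # corrupted in the meantime; that is exactly what this test relies on.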
5966*da0073e9SAndroid Build Coastguard Worker
5967*da0073e9SAndroid Build Coastguard Worker                # deliberately break grads
5968*da0073e9SAndroid Build Coastguard Worker                for j, param in enumerate(model.parameters()):
5969*da0073e9SAndroid Build Coastguard Worker                    param.grad.copy_(torch.inf if j % 2 else torch.nan)
5970*da0073e9SAndroid Build Coastguard Worker
5971*da0073e9SAndroid Build Coastguard Worker                scaler.step(optimizer)
5972*da0073e9SAndroid Build Coastguard Worker                scaler.update()
5973*da0073e9SAndroid Build Coastguard Worker
5974*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(all((p.isnan().any() or p.isinf().any()) for p in model.parameters()))
5975*da0073e9SAndroid Build Coastguard Worker
5976*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5977*da0073e9SAndroid Build Coastguard Worker    def test_grad_scale_will_not_overflow(self, device):
5978*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
5979*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Linear(5, 1).to(device)
5980*da0073e9SAndroid Build Coastguard Worker        optimizer = torch.optim.Adam(model.parameters())
5981*da0073e9SAndroid Build Coastguard Worker        scaler = torch.GradScaler(device=device.type, growth_interval=1, growth_factor=2**4, init_scale=1e38)
5982*da0073e9SAndroid Build Coastguard Worker        optimizer.zero_grad()
5983*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 5).to(device)
5984*da0073e9SAndroid Build Coastguard Worker        y = 1e-30 * torch.randn(1, 1).to(device)
5985*da0073e9SAndroid Build Coastguard Worker        l = ((model(x) - y) ** 2).mean()
5986*da0073e9SAndroid Build Coastguard Worker        scaler.scale(l).backward()
5987*da0073e9SAndroid Build Coastguard Worker        scaler.step(optimizer)
5988*da0073e9SAndroid Build Coastguard Worker        scaler.update()
5989*da0073e9SAndroid Build Coastguard Worker        assert not scaler._scale.isinf() and not scaler._scale.isnan()  # "!= nan" is always True
5990*da0073e9SAndroid Build Coastguard Worker
5991*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5992*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_clipping(self, device):
5993*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
5994*da0073e9SAndroid Build Coastguard Worker
5995*da0073e9SAndroid Build Coastguard Worker        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
5996*da0073e9SAndroid Build Coastguard Worker            max_norm = 0.2  # A reasonable value that actually has an effect, based on printouts of grads
5997*da0073e9SAndroid Build Coastguard Worker            for i, (input, target) in enumerate(data):
5998*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
5999*da0073e9SAndroid Build Coastguard Worker                output = model(input)
6000*da0073e9SAndroid Build Coastguard Worker                loss = loss_fn(output, target)
6001*da0073e9SAndroid Build Coastguard Worker                if try_scaling_api:
6002*da0073e9SAndroid Build Coastguard Worker                    scaler.scale(loss).backward()
6003*da0073e9SAndroid Build Coastguard Worker                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm * scaler.get_scale())
6004*da0073e9SAndroid Build Coastguard Worker                    if i == skip_iter and scaler.is_enabled():
6005*da0073e9SAndroid Build Coastguard Worker                        model[1].weight.grad.data.fill_(float('inf'))
6006*da0073e9SAndroid Build Coastguard Worker                    scaler.step(optimizer)
6007*da0073e9SAndroid Build Coastguard Worker                    scaler.update()
6008*da0073e9SAndroid Build Coastguard Worker                else:
6009*da0073e9SAndroid Build Coastguard Worker                    loss.backward()
6010*da0073e9SAndroid Build Coastguard Worker                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
6011*da0073e9SAndroid Build Coastguard Worker                    if (not scaler.is_enabled()) or (i != skip_iter):
6012*da0073e9SAndroid Build Coastguard Worker                        optimizer.step()
6013*da0073e9SAndroid Build Coastguard Worker
6014*da0073e9SAndroid Build Coastguard Worker        self._run_scaling_case(device.type, run, unskipped=3, skipped=1, atol=1e-5)
6015*da0073e9SAndroid Build Coastguard Worker
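    # Why run() above clips at max_norm * scaler.get_scale(): while the grads are
    # still scaled, the clipping threshold has to be scaled by the same factor to be
    # equivalent to unscaling first and clipping at max_norm.  Illustrative helper
    # with made-up values; never called by the suite.
    def _sketch_clip_threshold_on_scaled_grads(self):
        scale, max_norm = 128.0, 0.2
        p = torch.nn.Parameter(torch.ones(3))
        p.grad = torch.tensor([3.0, 4.0, 0.0]) * scale  # unscaled grad norm is 5.0
        torch.nn.utils.clip_grad_norm_(p, max_norm * scale)
        torch.testing.assert_close(p.grad / scale,
                                   torch.tensor([3.0, 4.0, 0.0]) * (max_norm / 5.0))
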
6016*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6017*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_clipping_separate_unscale(self, device):
6018*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
6019*da0073e9SAndroid Build Coastguard Worker
6020*da0073e9SAndroid Build Coastguard Worker        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
6021*da0073e9SAndroid Build Coastguard Worker            max_norm = 0.2  # A reasonable value that actually has an effect, based on printouts of grads
6022*da0073e9SAndroid Build Coastguard Worker            for i, (input, target) in enumerate(data):
6023*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
6024*da0073e9SAndroid Build Coastguard Worker                output = model(input)
6025*da0073e9SAndroid Build Coastguard Worker                loss = loss_fn(output, target)
6026*da0073e9SAndroid Build Coastguard Worker                if try_scaling_api:
6027*da0073e9SAndroid Build Coastguard Worker                    scaler.scale(loss).backward()
6028*da0073e9SAndroid Build Coastguard Worker                    if i == skip_iter and scaler.is_enabled():
6029*da0073e9SAndroid Build Coastguard Worker                        model[1].weight.grad.data.fill_(float('inf'))
6030*da0073e9SAndroid Build Coastguard Worker                    scaler.unscale_(optimizer)
6031*da0073e9SAndroid Build Coastguard Worker                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, error_if_nonfinite=False)
6032*da0073e9SAndroid Build Coastguard Worker                    scaler.step(optimizer)
6033*da0073e9SAndroid Build Coastguard Worker                    scaler.update()
6034*da0073e9SAndroid Build Coastguard Worker                else:
6035*da0073e9SAndroid Build Coastguard Worker                    loss.backward()
6036*da0073e9SAndroid Build Coastguard Worker                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
6037*da0073e9SAndroid Build Coastguard Worker                    if (not scaler.is_enabled()) or (i != skip_iter):
6038*da0073e9SAndroid Build Coastguard Worker                        optimizer.step()
6039*da0073e9SAndroid Build Coastguard Worker
6040*da0073e9SAndroid Build Coastguard Worker        self._run_scaling_case(device.type, run, unskipped=3, skipped=1)
6041*da0073e9SAndroid Build Coastguard Worker
6042*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6043*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, 'FIXME: fix this test for Windows')
6044*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_penalty(self, device):
6045*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
6046*da0073e9SAndroid Build Coastguard Worker
6047*da0073e9SAndroid Build Coastguard Worker        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
6048*da0073e9SAndroid Build Coastguard Worker            for i, (input, target) in enumerate(data):
6049*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
6050*da0073e9SAndroid Build Coastguard Worker                output = model(input)
6051*da0073e9SAndroid Build Coastguard Worker                loss = loss_fn(output, target)
6052*da0073e9SAndroid Build Coastguard Worker
6053*da0073e9SAndroid Build Coastguard Worker                if try_scaling_api:
6054*da0073e9SAndroid Build Coastguard Worker                    grad_params = torch.autograd.grad(scaler.scale(loss),
6055*da0073e9SAndroid Build Coastguard Worker                                                      model.parameters(), create_graph=True)
6056*da0073e9SAndroid Build Coastguard Worker                    inv_scale = 1. / scaler.get_scale()
6057*da0073e9SAndroid Build Coastguard Worker                    grad_params = [p * inv_scale for p in grad_params]
6058*da0073e9SAndroid Build Coastguard Worker                else:
6059*da0073e9SAndroid Build Coastguard Worker                    grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True)
6060*da0073e9SAndroid Build Coastguard Worker
6061*da0073e9SAndroid Build Coastguard Worker                grad_norm = 0
6062*da0073e9SAndroid Build Coastguard Worker                for grad in grad_params:
6063*da0073e9SAndroid Build Coastguard Worker                    grad_norm += grad.pow(2).sum()
6064*da0073e9SAndroid Build Coastguard Worker                grad_norm = grad_norm.sqrt()
6065*da0073e9SAndroid Build Coastguard Worker                loss = loss + grad_norm
6066*da0073e9SAndroid Build Coastguard Worker
6067*da0073e9SAndroid Build Coastguard Worker                if try_scaling_api:
6068*da0073e9SAndroid Build Coastguard Worker                    scaler.scale(loss).backward()
6069*da0073e9SAndroid Build Coastguard Worker                    if i == skip_iter and scaler.is_enabled():
6070*da0073e9SAndroid Build Coastguard Worker                        model[1].weight.grad.data.fill_(float('inf'))
6071*da0073e9SAndroid Build Coastguard Worker                    scaler.step(optimizer)
6072*da0073e9SAndroid Build Coastguard Worker                    scaler.update()
6073*da0073e9SAndroid Build Coastguard Worker                else:
6074*da0073e9SAndroid Build Coastguard Worker                    loss.backward()
6075*da0073e9SAndroid Build Coastguard Worker                    if (not scaler.is_enabled()) or (i != skip_iter):
6076*da0073e9SAndroid Build Coastguard Worker                        optimizer.step()
6077*da0073e9SAndroid Build Coastguard Worker
6078*da0073e9SAndroid Build Coastguard Worker        self._run_scaling_case(device.type, run, unskipped=3, skipped=1)
6079*da0073e9SAndroid Build Coastguard Worker
6080*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6081*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_accumulation(self, device):
6082*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
6083*da0073e9SAndroid Build Coastguard Worker
6084*da0073e9SAndroid Build Coastguard Worker        def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
6085*da0073e9SAndroid Build Coastguard Worker            iters_to_accumulate = 2
6086*da0073e9SAndroid Build Coastguard Worker            for i, (input, target) in enumerate(data):
6087*da0073e9SAndroid Build Coastguard Worker                output = model(input)
6088*da0073e9SAndroid Build Coastguard Worker                loss = loss_fn(output, target)
6089*da0073e9SAndroid Build Coastguard Worker                loss = loss / iters_to_accumulate
6090*da0073e9SAndroid Build Coastguard Worker                if try_scaling_api:
6091*da0073e9SAndroid Build Coastguard Worker                    scaler.scale(loss).backward()
6092*da0073e9SAndroid Build Coastguard Worker                else:
6093*da0073e9SAndroid Build Coastguard Worker                    loss.backward()
6094*da0073e9SAndroid Build Coastguard Worker                if (i + 1) % iters_to_accumulate == 0:
6095*da0073e9SAndroid Build Coastguard Worker                    if try_scaling_api:
6096*da0073e9SAndroid Build Coastguard Worker                        scaler.step(optimizer)
6097*da0073e9SAndroid Build Coastguard Worker                        scaler.update()
6098*da0073e9SAndroid Build Coastguard Worker                        optimizer.zero_grad()
6099*da0073e9SAndroid Build Coastguard Worker                    else:
6100*da0073e9SAndroid Build Coastguard Worker                        optimizer.step()
6101*da0073e9SAndroid Build Coastguard Worker                        optimizer.zero_grad()
6102*da0073e9SAndroid Build Coastguard Worker
6103*da0073e9SAndroid Build Coastguard Worker        self._run_scaling_case(device.type, run, unskipped=2, skipped=0)
6104*da0073e9SAndroid Build Coastguard Worker
6105*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6106*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_multiple(self, device):
6107*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
6108*da0073e9SAndroid Build Coastguard Worker        # Tests gradient scaling with 2 models and 2 optimizers that both receive gradients from 2 losses.
6109*da0073e9SAndroid Build Coastguard Worker        # Some of the logic here cannot reuse the generic helper functions created for the 1-optimizer cases.
6110*da0073e9SAndroid Build Coastguard Worker        for enabled in True, False:
6111*da0073e9SAndroid Build Coastguard Worker            mod_control0, mod_scaling0, opt_control0, opt_scaling0, data, loss_fn, skip_iter = \
6112*da0073e9SAndroid Build Coastguard Worker                _create_scaling_case(device.type)
6113*da0073e9SAndroid Build Coastguard Worker            mod_control1, mod_scaling1, opt_control1, opt_scaling1 = \
6114*da0073e9SAndroid Build Coastguard Worker                _create_scaling_models_optimizers(device.type)
6115*da0073e9SAndroid Build Coastguard Worker
6116*da0073e9SAndroid Build Coastguard Worker            GradScaler = partial(torch.GradScaler, device=device.type)
6117*da0073e9SAndroid Build Coastguard Worker            scaler = GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled, growth_interval=1)
6118*da0073e9SAndroid Build Coastguard Worker
6119*da0073e9SAndroid Build Coastguard Worker            def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
6120*da0073e9SAndroid Build Coastguard Worker                for i, (input, target) in enumerate(data):
6121*da0073e9SAndroid Build Coastguard Worker                    optimizer0.zero_grad()
6122*da0073e9SAndroid Build Coastguard Worker                    optimizer1.zero_grad()
6123*da0073e9SAndroid Build Coastguard Worker                    output0 = model0(input)
6124*da0073e9SAndroid Build Coastguard Worker                    output1 = model1(input)
6125*da0073e9SAndroid Build Coastguard Worker                    loss0 = loss_fn(0.3 * output0 + 0.7 * output1, target)
6126*da0073e9SAndroid Build Coastguard Worker                    loss1 = loss_fn(0.6 * output0 - 0.4 * output1, target)
6127*da0073e9SAndroid Build Coastguard Worker
6128*da0073e9SAndroid Build Coastguard Worker                    if try_scaling_api:
6129*da0073e9SAndroid Build Coastguard Worker                        scaler.scale(loss0).backward(retain_graph=True)
6130*da0073e9SAndroid Build Coastguard Worker                        scaler.scale(loss1).backward()
6131*da0073e9SAndroid Build Coastguard Worker                        if i == skip_iter and scaler.is_enabled():
6132*da0073e9SAndroid Build Coastguard Worker                            model1[1].weight.grad.data.fill_(float('inf'))
6133*da0073e9SAndroid Build Coastguard Worker
6134*da0073e9SAndroid Build Coastguard Worker                        # As an additional stress test, separately unscale for one of the optimizers.
6135*da0073e9SAndroid Build Coastguard Worker                        scaler.unscale_(optimizer0)
6136*da0073e9SAndroid Build Coastguard Worker
6137*da0073e9SAndroid Build Coastguard Worker                        scaler.step(optimizer0)
6138*da0073e9SAndroid Build Coastguard Worker                        scaler.step(optimizer1)
6139*da0073e9SAndroid Build Coastguard Worker                        scaler.update()
6140*da0073e9SAndroid Build Coastguard Worker                    else:
6141*da0073e9SAndroid Build Coastguard Worker                        loss0.backward(retain_graph=True)
6142*da0073e9SAndroid Build Coastguard Worker                        loss1.backward()
6143*da0073e9SAndroid Build Coastguard Worker                        optimizer0.step()
6144*da0073e9SAndroid Build Coastguard Worker                        if (not scaler.is_enabled()) or (i != skip_iter):
6145*da0073e9SAndroid Build Coastguard Worker                            optimizer1.step()
6146*da0073e9SAndroid Build Coastguard Worker
6147*da0073e9SAndroid Build Coastguard Worker            run(mod_control0, mod_control1, opt_control0, opt_control1, False)
6148*da0073e9SAndroid Build Coastguard Worker            run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True)
6149*da0073e9SAndroid Build Coastguard Worker
6150*da0073e9SAndroid Build Coastguard Worker            # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once.
6151*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(scaler.get_scale() == (128. * scaler.get_growth_factor()**3 *
6152*da0073e9SAndroid Build Coastguard Worker                                                   scaler.get_backoff_factor()**1 if enabled else 1.0))
6153*da0073e9SAndroid Build Coastguard Worker
6154*da0073e9SAndroid Build Coastguard Worker            for c, s in zip(chain(mod_control0.parameters(), mod_control1.parameters()),
6155*da0073e9SAndroid Build Coastguard Worker                            chain(mod_scaling0.parameters(), mod_scaling1.parameters())):
6156*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c, s, rtol=1e-5, atol=1e-7)
6157*da0073e9SAndroid Build Coastguard Worker
6158*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6159*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaler_pass_itself(self, device):
6160*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
6161*da0073e9SAndroid Build Coastguard Worker        GradScaler = partial(torch.amp.GradScaler, device=device.type)
6162*da0073e9SAndroid Build Coastguard Worker
6163*da0073e9SAndroid Build Coastguard Worker        class _PlaceHolderOptimizer(torch.optim.Optimizer):
6164*da0073e9SAndroid Build Coastguard Worker            tester = self
6165*da0073e9SAndroid Build Coastguard Worker
6166*da0073e9SAndroid Build Coastguard Worker            def __init__(self, params, defaults=None):
6167*da0073e9SAndroid Build Coastguard Worker                if defaults is None:
6168*da0073e9SAndroid Build Coastguard Worker                    defaults = {}
6169*da0073e9SAndroid Build Coastguard Worker                super().__init__(params, defaults)
6170*da0073e9SAndroid Build Coastguard Worker                self._step_supports_amp_scaling = True
6171*da0073e9SAndroid Build Coastguard Worker
6172*da0073e9SAndroid Build Coastguard Worker        class Optimizer1(_PlaceHolderOptimizer):
6173*da0073e9SAndroid Build Coastguard Worker            def step(self, closure=None, *, grad_scaler=None):
6174*da0073e9SAndroid Build Coastguard Worker                self.tester.assertTrue(isinstance(grad_scaler, torch.amp.GradScaler))
6175*da0073e9SAndroid Build Coastguard Worker                self.tester.assertFalse(hasattr(self, "grad_scale"))
6176*da0073e9SAndroid Build Coastguard Worker                self.tester.assertFalse(hasattr(self, "found_inf"))
6177*da0073e9SAndroid Build Coastguard Worker
6178*da0073e9SAndroid Build Coastguard Worker        class Optimizer2(_PlaceHolderOptimizer):
6179*da0073e9SAndroid Build Coastguard Worker            def step(self, closure=None):
6180*da0073e9SAndroid Build Coastguard Worker                self.tester.assertTrue(isinstance(self.grad_scale, torch.Tensor))
6181*da0073e9SAndroid Build Coastguard Worker                self.tester.assertTrue(isinstance(self.found_inf, torch.Tensor))
6182*da0073e9SAndroid Build Coastguard Worker
6183*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4).to(device)
6184*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.Linear(4, 1).to(device)
6185*da0073e9SAndroid Build Coastguard Worker        o1 = Optimizer1(m.parameters())
6186*da0073e9SAndroid Build Coastguard Worker        o2 = Optimizer2(m.parameters())
6187*da0073e9SAndroid Build Coastguard Worker        scaler = GradScaler(init_scale=2.0)
6188*da0073e9SAndroid Build Coastguard Worker
6189*da0073e9SAndroid Build Coastguard Worker        with torch.autocast(device_type=device.type, dtype=torch.half):
6190*da0073e9SAndroid Build Coastguard Worker            y = m(x)
6191*da0073e9SAndroid Build Coastguard Worker            loss = y.mean()
6192*da0073e9SAndroid Build Coastguard Worker        scaler.scale(loss).backward()
6193*da0073e9SAndroid Build Coastguard Worker        with self.assertWarns(FutureWarning):
6194*da0073e9SAndroid Build Coastguard Worker            scaler.step(o1)
6195*da0073e9SAndroid Build Coastguard Worker        scaler.step(o2)
6196*da0073e9SAndroid Build Coastguard Worker        scaler.update()
6197*da0073e9SAndroid Build Coastguard Worker
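    # Illustrative sketch of the protocol exercised above (not used by the suite):
    # an optimizer that advertises `_step_supports_amp_scaling = True` is stepped by
    # GradScaler without its grads being unscaled first; the scaler instead attaches
    # `grad_scale` and `found_inf` tensors (as Optimizer2 asserts) for the optimizer
    # to consume itself.  The class below and its name are hypothetical.
    class _SketchScalingAwareSGD(torch.optim.Optimizer):
        def __init__(self, params, lr=0.1):
            super().__init__(params, {"lr": lr})
            self._step_supports_amp_scaling = True

        @torch.no_grad()
        def step(self, closure=None):
            grad_scale = getattr(self, "grad_scale", None)
            found_inf = getattr(self, "found_inf", None)
            if found_inf is not None and bool(found_inf):
                return None  # mirror GradScaler's policy of skipping overflowed steps
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    grad = p.grad
                    if grad_scale is not None:
                        grad = grad / grad_scale.to(grad.device)
                    p.add_(grad, alpha=-group["lr"])
            return None
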
6198*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6199*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaler_deprecated_warning(self, device):
6200*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
6201*da0073e9SAndroid Build Coastguard Worker        GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler
6202*da0073e9SAndroid Build Coastguard Worker
6203*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
6204*da0073e9SAndroid Build Coastguard Worker            FutureWarning,
6205*da0073e9SAndroid Build Coastguard Worker            rf"`torch.{device.type}.amp.GradScaler\(args...\)` is deprecated.",
6206*da0073e9SAndroid Build Coastguard Worker        ):
6207*da0073e9SAndroid Build Coastguard Worker            _ = GradScaler(init_scale=2.0)
6208*da0073e9SAndroid Build Coastguard Worker
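    # For reference, the replacement spelling suggested by that warning is the
    # device-generic constructor (shown as a comment only, not executed here):
    #     scaler = torch.amp.GradScaler("cuda", init_scale=2.0)   # or "cpu"
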
6209*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.double, torch.half)
6210*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float, torch.double, torch.bfloat16, torch.half)
6211*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
6212*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_cpu(self, device, dtype):
6213*da0073e9SAndroid Build Coastguard Worker        def make_prob_dist(shape, is_contiguous):
6214*da0073e9SAndroid Build Coastguard Worker            if is_contiguous:
6215*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half or dtype == torch.bfloat16:
6216*da0073e9SAndroid Build Coastguard Worker                    return torch.zeros(shape, device=device).uniform_().to(dtype=dtype)
6217*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(shape, device=device, dtype=dtype).uniform_()
6218*da0073e9SAndroid Build Coastguard Worker            elif len(shape) == 1:
6219*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half or dtype == torch.bfloat16:
6220*da0073e9SAndroid Build Coastguard Worker                    return torch.zeros((shape + [5]), device=device).uniform_().to(dtype=dtype)[:, 2]
6221*da0073e9SAndroid Build Coastguard Worker                return torch.zeros((shape + [5]), device=device, dtype=dtype).uniform_()[:, 2]
6222*da0073e9SAndroid Build Coastguard Worker            else:
6223*da0073e9SAndroid Build Coastguard Worker                # num dim = 2
6224*da0073e9SAndroid Build Coastguard Worker                new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
6225*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half or dtype == torch.bfloat16:
6226*da0073e9SAndroid Build Coastguard Worker                    prob_dist = torch.zeros(new_shape, device=device).uniform_().to(dtype=dtype)
6227*da0073e9SAndroid Build Coastguard Worker                else:
6228*da0073e9SAndroid Build Coastguard Worker                    prob_dist = torch.zeros(new_shape, device=device, dtype=dtype).uniform_()
6229*da0073e9SAndroid Build Coastguard Worker                prob_dist = prob_dist.transpose(1, 4)
6230*da0073e9SAndroid Build Coastguard Worker                prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
6231*da0073e9SAndroid Build Coastguard Worker                assert not prob_dist.is_contiguous()  # sanity check
6232*da0073e9SAndroid Build Coastguard Worker                return prob_dist
6233*da0073e9SAndroid Build Coastguard Worker
6234*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to elementwise ternary test suite
6235*da0073e9SAndroid Build Coastguard Worker    # As the test fails with Runtime Error not raised on XLA
6236*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6237*da0073e9SAndroid Build Coastguard Worker    def test_where_scalar_handcrafted_values(self, device):
6238*da0073e9SAndroid Build Coastguard Worker        # Tests ScalarxScalar, ScalarxTensor and TensorxScalar
6239*da0073e9SAndroid Build Coastguard Worker        # variant of `where` against NumPy version with
6240*da0073e9SAndroid Build Coastguard Worker        # handcrafted values.
6241*da0073e9SAndroid Build Coastguard Worker        condition_shape = (5, 5)
6242*da0073e9SAndroid Build Coastguard Worker        dtypes = (
6243*da0073e9SAndroid Build Coastguard Worker            torch.bool, torch.uint8, torch.int8, torch.int16, torch.int64,
6244*da0073e9SAndroid Build Coastguard Worker            torch.float16, torch.float32, torch.float64,
6245*da0073e9SAndroid Build Coastguard Worker            torch.complex64, torch.complex128,
6246*da0073e9SAndroid Build Coastguard Worker        )
6247*da0073e9SAndroid Build Coastguard Worker        shapes = ((), (5,), (1, 5),)
6248*da0073e9SAndroid Build Coastguard Worker
6249*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
6250*da0073e9SAndroid Build Coastguard Worker            tensors = (torch.empty(shape, dtype=dtype, device=device).fill_(17)
6251*da0073e9SAndroid Build Coastguard Worker                       for shape, dtype in product(shapes, dtypes))
6252*da0073e9SAndroid Build Coastguard Worker
6253*da0073e9SAndroid Build Coastguard Worker        # Use different values for `x` and `y`
6254*da0073e9SAndroid Build Coastguard Worker        # as they are the output values which are compared.
6255*da0073e9SAndroid Build Coastguard Worker        x_vals = (True, 3, 7.0, 1 + 0.5j)
6256*da0073e9SAndroid Build Coastguard Worker        y_vals = tuple(itertools.chain((False, 4, 8.0, 2 + 0.5j), tensors))  # materialized so each x re-iterates the same y values
6257*da0073e9SAndroid Build Coastguard Worker        for x in x_vals:
6258*da0073e9SAndroid Build Coastguard Worker            for y in y_vals:
6259*da0073e9SAndroid Build Coastguard Worker                condition = torch.empty(*condition_shape, dtype=torch.bool, device=device).bernoulli_()
6260*da0073e9SAndroid Build Coastguard Worker                common_dtype = torch.result_type(x, y)
6261*da0073e9SAndroid Build Coastguard Worker
6262*da0073e9SAndroid Build Coastguard Worker                def check_equal(condition, x, y):
6263*da0073e9SAndroid Build Coastguard Worker                    condition_np = condition.cpu().numpy()
6264*da0073e9SAndroid Build Coastguard Worker                    x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
6265*da0073e9SAndroid Build Coastguard Worker                    y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
6266*da0073e9SAndroid Build Coastguard Worker
6267*da0073e9SAndroid Build Coastguard Worker                    # NumPy aggressively promotes to double, hence cast to output to correct dtype
6268*da0073e9SAndroid Build Coastguard Worker                    expected = torch.from_numpy(np.where(condition_np, x_np, y_np)).to(common_dtype)
6269*da0073e9SAndroid Build Coastguard Worker                    result = torch.where(condition, x, y)
6270*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected, result)
6271*da0073e9SAndroid Build Coastguard Worker
6272*da0073e9SAndroid Build Coastguard Worker                check_equal(condition, x, y)
6273*da0073e9SAndroid Build Coastguard Worker                check_equal(condition, y, x)
6274*da0073e9SAndroid Build Coastguard Worker                if self.device_type == "cuda":
6275*da0073e9SAndroid Build Coastguard Worker                    check_equal(condition, torch.tensor(x), y)
6276*da0073e9SAndroid Build Coastguard Worker                    check_equal(condition, y, torch.tensor(x))
6277*da0073e9SAndroid Build Coastguard Worker                    if not isinstance(y, torch.Tensor):
6278*da0073e9SAndroid Build Coastguard Worker                        check_equal(condition, torch.tensor(y), torch.tensor(x))
6279*da0073e9SAndroid Build Coastguard Worker                    if isinstance(y, torch.Tensor) and y.ndim > 0:
6280*da0073e9SAndroid Build Coastguard Worker                        check_equal(torch.tensor(True), x, y)
6281*da0073e9SAndroid Build Coastguard Worker                        check_equal(torch.tensor(True), y, x)
6282*da0073e9SAndroid Build Coastguard Worker
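    # Small sketch (never called by the suite) of the promotion rule the NumPy
    # comparison above works around: torch.where promotes its inputs to
    # torch.result_type(x, y), and Python scalars promote "weakly", so they do not
    # widen a half tensor the way NumPy's float64 default would.
    def _sketch_where_type_promotion(self):
        cond = torch.tensor([True, False, True])
        assert torch.where(cond, 3, 8.0).dtype == torch.result_type(3, 8.0)
        half = torch.ones(3, dtype=torch.float16)
        assert torch.where(cond, half, 8.0).dtype == torch.float16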
6283*da0073e9SAndroid Build Coastguard Worker
6284*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("FIXME")
6285*da0073e9SAndroid Build Coastguard Worker    def test_hook_remove(self, device):
6286*da0073e9SAndroid Build Coastguard Worker        # Reference: https://github.com/pytorch/pytorch/issues/58354
6287*da0073e9SAndroid Build Coastguard Worker        def _test_helper(remove_hook):
6288*da0073e9SAndroid Build Coastguard Worker            def install_hook(tensor):
6289*da0073e9SAndroid Build Coastguard Worker                handle = None
6290*da0073e9SAndroid Build Coastguard Worker
6291*da0073e9SAndroid Build Coastguard Worker                def hook(tensor):
6292*da0073e9SAndroid Build Coastguard Worker                    if remove_hook:
6293*da0073e9SAndroid Build Coastguard Worker                        handle.remove()
6294*da0073e9SAndroid Build Coastguard Worker                    return torch.zeros_like(tensor)
6295*da0073e9SAndroid Build Coastguard Worker                handle = tensor.register_hook(hook)
6296*da0073e9SAndroid Build Coastguard Worker
6297*da0073e9SAndroid Build Coastguard Worker            t = torch.ones((1, 5), device=device, requires_grad=True)
6298*da0073e9SAndroid Build Coastguard Worker            install_hook(t)
6299*da0073e9SAndroid Build Coastguard Worker
6300*da0073e9SAndroid Build Coastguard Worker            # First call to backward
6301*da0073e9SAndroid Build Coastguard Worker            t.mean().backward()
6302*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.grad, torch.zeros_like(t))
6303*da0073e9SAndroid Build Coastguard Worker
6304*da0073e9SAndroid Build Coastguard Worker            # Second call to backward
6305*da0073e9SAndroid Build Coastguard Worker            t.mean().backward()
6306*da0073e9SAndroid Build Coastguard Worker            if remove_hook:
6307*da0073e9SAndroid Build Coastguard Worker                # After removing the hook, make sure the usual gradient is returned
6308*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.grad, 0.2 * torch.ones_like(t))
6309*da0073e9SAndroid Build Coastguard Worker            else:
6310*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.grad, torch.zeros_like(t))
6311*da0073e9SAndroid Build Coastguard Worker
6312*da0073e9SAndroid Build Coastguard Worker        _test_helper(remove_hook=True)
6313*da0073e9SAndroid Build Coastguard Worker        _test_helper(remove_hook=False)
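
    # Why 0.2 above: for t of shape (1, 5), t.mean() sends a gradient of 1/5 = 0.2 to every
    # element, so once the hook is removed the plain gradient is 0.2 * torch.ones_like(t).
    # A minimal sketch of the hook/handle mechanics (illustrative helper, assumes CPU):
    def _hook_handle_sketch(self):
        t = torch.ones(5, requires_grad=True)
        handle = t.register_hook(lambda grad: grad * 0)  # replace the incoming gradient with zeros
        t.mean().backward()  # t.grad is all zeros while the hook is live
        handle.remove()  # later backward calls accumulate the plain 1/5 = 0.2 gradients instead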
6314*da0073e9SAndroid Build Coastguard Worker
6315*da0073e9SAndroid Build Coastguard Worker    # FIXME: get PyTorch/XLA to run test_testing
6316*da0073e9SAndroid Build Coastguard Worker    # This test should ideally be in test_testing.py,
6317*da0073e9SAndroid Build Coastguard Worker    # but since pytorch/xla runs tests from test_torch.py, we have it here.
6318*da0073e9SAndroid Build Coastguard Worker    @skipXLA
6319*da0073e9SAndroid Build Coastguard Worker    def test_skip_xla(self, device):
6320*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'xla':
6321*da0073e9SAndroid Build Coastguard Worker            # Should not reach here!
6322*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(False)
6323*da0073e9SAndroid Build Coastguard Worker
6324*da0073e9SAndroid Build Coastguard Worker    # FIXME: get PyTorch/XLA to run test_testing
6325*da0073e9SAndroid Build Coastguard Worker    # This test should ideally be in test_testing.py,
6326*da0073e9SAndroid Build Coastguard Worker    # but since pytorch/xla runs tests from test_torch.py, we have it here.
6327*da0073e9SAndroid Build Coastguard Worker    @expectedFailureXLA
6328*da0073e9SAndroid Build Coastguard Worker    def test_expected_failure_xla(self, device):
6329*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'xla':
6330*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(False)
6331*da0073e9SAndroid Build Coastguard Worker
6332*da0073e9SAndroid Build Coastguard Worker    # FIXME: get PyTorch/XLA to run test_testing
6333*da0073e9SAndroid Build Coastguard Worker    # This test should ideally be in test_testing.py,
6334*da0073e9SAndroid Build Coastguard Worker    # but since pytorch/xla runs tests from test_torch.py, we have it here.
6335*da0073e9SAndroid Build Coastguard Worker    def test_assertRaisesRegex_ignore_msg_non_native_device(self, device):
6336*da0073e9SAndroid Build Coastguard Worker        # Verify that self.assertRaisesRegex only checks the error type and ignores
6337*da0073e9SAndroid Build Coastguard Worker        # the message for non-native devices.
6338*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((10, 3), device=device)
6339*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(10, dtype=torch.int64, device=device).random_(0, 3)
6340*da0073e9SAndroid Build Coastguard Worker        invalid_weight = torch.randn(4, device=device)
6341*da0073e9SAndroid Build Coastguard Worker        msg = "weight tensor should be defined either for all 3 classes or no classes"
6342*da0073e9SAndroid Build Coastguard Worker
6343*da0073e9SAndroid Build Coastguard Worker        # XLA raises RuntimeError with a different message.
6344*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
6345*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.nll_loss(x, t, weight=invalid_weight)
6346*da0073e9SAndroid Build Coastguard Worker
6347*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32))
6348*da0073e9SAndroid Build Coastguard Worker    def test_copy_(self, device, dtype):
6349*da0073e9SAndroid Build Coastguard Worker        def can_cast(src_dtype, dst_dtype):
6350*da0073e9SAndroid Build Coastguard Worker            # torch.can_cast(torch.int16, torch.uint8) returns True
6351*da0073e9SAndroid Build Coastguard Worker            # which isn't actually a safe cast.
6352*da0073e9SAndroid Build Coastguard Worker            # This function returns False in this case.
6353*da0073e9SAndroid Build Coastguard Worker            def is_unsigned_int(dtype):
6354*da0073e9SAndroid Build Coastguard Worker                return dtype is torch.uint8
6355*da0073e9SAndroid Build Coastguard Worker
6356*da0073e9SAndroid Build Coastguard Worker            if is_unsigned_int(dst_dtype):
6357*da0073e9SAndroid Build Coastguard Worker                return is_unsigned_int(src_dtype)
6358*da0073e9SAndroid Build Coastguard Worker            return torch.can_cast(src_dtype, dst_dtype)
6359*da0073e9SAndroid Build Coastguard Worker
6360*da0073e9SAndroid Build Coastguard Worker        def make_tensor_wrapper(shape, dtype):
6361*da0073e9SAndroid Build Coastguard Worker            if dtype is not torch.complex32:
6362*da0073e9SAndroid Build Coastguard Worker                # make_tensor does not support generating
6363*da0073e9SAndroid Build Coastguard Worker                # complex32 tensors
6364*da0073e9SAndroid Build Coastguard Worker                return make_tensor(shape, device=device, dtype=dtype)
6365*da0073e9SAndroid Build Coastguard Worker            return torch.randn(shape, device=device, dtype=dtype)
6366*da0073e9SAndroid Build Coastguard Worker
6367*da0073e9SAndroid Build Coastguard Worker        t = make_tensor_wrapper((50,), dtype)
6368*da0073e9SAndroid Build Coastguard Worker        src_dtypes = all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32)
6369*da0073e9SAndroid Build Coastguard Worker        for src_dtype in src_dtypes:
6370*da0073e9SAndroid Build Coastguard Worker            src = make_tensor_wrapper((50,), dtype=src_dtype)
6371*da0073e9SAndroid Build Coastguard Worker            t.copy_(src)
6372*da0073e9SAndroid Build Coastguard Worker            dst = make_tensor_wrapper((50, ), dtype=src_dtype)
6373*da0073e9SAndroid Build Coastguard Worker            if can_cast(src_dtype, dtype):
6374*da0073e9SAndroid Build Coastguard Worker                rtol = None
6375*da0073e9SAndroid Build Coastguard Worker                atol = None
6376*da0073e9SAndroid Build Coastguard Worker                if dtype in (torch.half, torch.complex32):
6377*da0073e9SAndroid Build Coastguard Worker                    rtol = 1e-3
6378*da0073e9SAndroid Build Coastguard Worker                    atol = 1e-3
6379*da0073e9SAndroid Build Coastguard Worker                if dtype in (torch.bfloat16,):
6380*da0073e9SAndroid Build Coastguard Worker                    rtol = 1e-2
6381*da0073e9SAndroid Build Coastguard Worker                    atol = 1e-2
6382*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(src, dst.copy_(t), rtol=rtol, atol=atol)
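
    # A small illustration of the can_cast() caveat handled above: torch.can_cast only checks
    # dtype-level castability, not value safety, so a signed-to-uint8 cast is reported as
    # allowed even though negative values would wrap (example values assumed):
    #   >>> torch.can_cast(torch.int16, torch.uint8)
    #   True
    #   >>> torch.tensor([-1], dtype=torch.int16).to(torch.uint8)
    #   tensor([255], dtype=torch.uint8)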
6383*da0073e9SAndroid Build Coastguard Worker
6384*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(
6385*da0073e9SAndroid Build Coastguard Worker        torch.bool, torch.half, torch.bfloat16, torch.complex32,
6386*da0073e9SAndroid Build Coastguard Worker        torch.uint16, torch.uint32, torch.uint64))
6387*da0073e9SAndroid Build Coastguard Worker    def test_item(self, device, dtype):
6388*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type == 'xla' and dtype in [torch.uint16, torch.uint32, torch.uint64]:
6389*da0073e9SAndroid Build Coastguard Worker            self.skipTest('uint16,32,64 not implemented on XLA')
6390*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((), device=device, dtype=dtype)
6391*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, t.item())
6392*da0073e9SAndroid Build Coastguard Worker
6393*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
6394*da0073e9SAndroid Build Coastguard Worker    def test_masked_scatter_inplace_noncontiguous(self, device):
6395*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros(5, 2, dtype=torch.long, device=device)
6396*da0073e9SAndroid Build Coastguard Worker        t_non_contig = t.transpose(0, 1)
6397*da0073e9SAndroid Build Coastguard Worker        t_contig = t_non_contig.contiguous()
6398*da0073e9SAndroid Build Coastguard Worker
6399*da0073e9SAndroid Build Coastguard Worker        assert t_contig.is_contiguous()
6400*da0073e9SAndroid Build Coastguard Worker        assert not t_non_contig.is_contiguous()
6401*da0073e9SAndroid Build Coastguard Worker
6402*da0073e9SAndroid Build Coastguard Worker        mask = torch.tensor([[False, True], [False, True], [False, False], [True, True], [True, True]], device=device)
6403*da0073e9SAndroid Build Coastguard Worker        mask_non_contig = mask.transpose(0, 1)
6404*da0073e9SAndroid Build Coastguard Worker        mask_contig = mask_non_contig.contiguous()
6405*da0073e9SAndroid Build Coastguard Worker
6406*da0073e9SAndroid Build Coastguard Worker        assert mask_contig.is_contiguous()
6407*da0073e9SAndroid Build Coastguard Worker        assert not mask_non_contig.is_contiguous()
6408*da0073e9SAndroid Build Coastguard Worker
6409*da0073e9SAndroid Build Coastguard Worker        # The source is always made contiguous by the op.
6410*da0073e9SAndroid Build Coastguard Worker        source = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 9]], device=device)
6411*da0073e9SAndroid Build Coastguard Worker
6412*da0073e9SAndroid Build Coastguard Worker        # t: contig, mask: contig
6413*da0073e9SAndroid Build Coastguard Worker        expected = t_contig.masked_scatter_(mask_contig, source)
6414*da0073e9SAndroid Build Coastguard Worker
6415*da0073e9SAndroid Build Coastguard Worker        # t: non-contig, mask: non-contig
6416*da0073e9SAndroid Build Coastguard Worker        actual = t_non_contig.masked_scatter_(mask_non_contig, source)
6417*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected)
6418*da0073e9SAndroid Build Coastguard Worker
6419*da0073e9SAndroid Build Coastguard Worker        # t: contig, mask: non-contig
6420*da0073e9SAndroid Build Coastguard Worker        actual = t_contig.masked_scatter_(mask_non_contig, source)
6421*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected)
6422*da0073e9SAndroid Build Coastguard Worker
6423*da0073e9SAndroid Build Coastguard Worker        # t: non-contig, mask: contig
6424*da0073e9SAndroid Build Coastguard Worker        actual = t_non_contig.masked_scatter_(mask_contig, source)
6425*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected)
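
    # masked_scatter_(mask, source) fills the True positions of ``mask`` in ``self`` with
    # elements taken from ``source`` in row-major order, which is why the op is free to make
    # the source contiguous first. A tiny sketch with assumed values:
    #   >>> torch.zeros(4, dtype=torch.long).masked_scatter_(
    #   ...     torch.tensor([True, False, True, False]), torch.tensor([7, 8]))
    #   tensor([7, 0, 8, 0])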
6426*da0073e9SAndroid Build Coastguard Worker
6427*da0073e9SAndroid Build Coastguard Worker
6428*da0073e9SAndroid Build Coastguard Worker# Tests that compare a device's computation with the (gold-standard) CPU's.
6429*da0073e9SAndroid Build Coastguard Workerclass TestDevicePrecision(TestCase):
6430*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
6431*da0073e9SAndroid Build Coastguard Worker
6432*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to indexing test suite
6433*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
6434*da0073e9SAndroid Build Coastguard Worker    def test_index_add_bfloat16(self, device):
6435*da0073e9SAndroid Build Coastguard Worker        inp_tensor = torch.randn(5, 3, device='cpu').bfloat16()
6436*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.bfloat16, device='cpu')
6437*da0073e9SAndroid Build Coastguard Worker        index = torch.tensor([0, 4, 2], device='cpu')
6438*da0073e9SAndroid Build Coastguard Worker        out_cpu = inp_tensor.index_add(0, index, t)
6439*da0073e9SAndroid Build Coastguard Worker
6440*da0073e9SAndroid Build Coastguard Worker        inp_tensor = inp_tensor.to(device=device)
6441*da0073e9SAndroid Build Coastguard Worker        t = t.to(device=device)
6442*da0073e9SAndroid Build Coastguard Worker        index = index.to(device=device)
6443*da0073e9SAndroid Build Coastguard Worker        out_gpu = inp_tensor.index_add(0, index, t)
6444*da0073e9SAndroid Build Coastguard Worker
6445*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0)
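
    # index_add(dim, index, t) accumulates t[i] into self at position index[i] along ``dim``
    # (repeated indices accumulate), which is what the CPU/GPU comparison above relies on.
    # A tiny sketch with assumed values:
    #   >>> torch.zeros(3).index_add(0, torch.tensor([0, 2, 0]), torch.tensor([1., 2., 3.]))
    #   tensor([4., 0., 2.])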
6446*da0073e9SAndroid Build Coastguard Worker
6447*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to serialization test suite
6448*da0073e9SAndroid Build Coastguard Worker    def test_device_serialization(self, device):
6449*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, device=device)
6450*da0073e9SAndroid Build Coastguard Worker
6451*da0073e9SAndroid Build Coastguard Worker        with tempfile.NamedTemporaryFile() as f:
6452*da0073e9SAndroid Build Coastguard Worker            torch.save(x, f)
6453*da0073e9SAndroid Build Coastguard Worker            f.seek(0)
6454*da0073e9SAndroid Build Coastguard Worker            x_copy = torch.load(f)
6455*da0073e9SAndroid Build Coastguard Worker
6456*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_copy, x)
6457*da0073e9SAndroid Build Coastguard Worker        self.assertIs(type(x_copy), type(x))
6458*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_copy.device, x.device)
6459*da0073e9SAndroid Build Coastguard Worker
6460*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to serialization test suite
6461*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
6462*da0073e9SAndroid Build Coastguard Worker    def test_multidevice_serialization(self, devices):
6463*da0073e9SAndroid Build Coastguard Worker        x = [torch.randn(4, 4, device=devices[0]),
6464*da0073e9SAndroid Build Coastguard Worker             torch.randn(4, 4, device=devices[1])]
6465*da0073e9SAndroid Build Coastguard Worker
6466*da0073e9SAndroid Build Coastguard Worker        with tempfile.NamedTemporaryFile() as f:
6467*da0073e9SAndroid Build Coastguard Worker            torch.save(x, f)
6468*da0073e9SAndroid Build Coastguard Worker            f.seek(0)
6469*da0073e9SAndroid Build Coastguard Worker            x_copy = torch.load(f)
6470*da0073e9SAndroid Build Coastguard Worker
6471*da0073e9SAndroid Build Coastguard Worker        for original, cp in zip(x, x_copy):
6472*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cp, original)
6473*da0073e9SAndroid Build Coastguard Worker            self.assertIs(type(cp), type(original))
6474*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cp.device, original.device)
6475*da0073e9SAndroid Build Coastguard Worker
6476*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to data movement test suite
6477*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
6478*da0073e9SAndroid Build Coastguard Worker    def test_copy_noncontig(self, devices):
6479*da0073e9SAndroid Build Coastguard Worker        def do_test(d0, d1):
6480*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device=d0)
6481*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor([0, 0, 0, 0, 0, 0], device=d1)
6482*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(x.dtype, y.dtype)
6483*da0073e9SAndroid Build Coastguard Worker
6484*da0073e9SAndroid Build Coastguard Worker            y[::2].copy_(x[::2])
6485*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, [1, 0, 3, 0, 5, 0])
6486*da0073e9SAndroid Build Coastguard Worker
6487*da0073e9SAndroid Build Coastguard Worker        do_test('cpu', devices[0])
6488*da0073e9SAndroid Build Coastguard Worker        do_test(devices[0], 'cpu')
6489*da0073e9SAndroid Build Coastguard Worker
6490*da0073e9SAndroid Build Coastguard Worker        if len(devices) > 1:
6491*da0073e9SAndroid Build Coastguard Worker            do_test(devices[0], devices[1])
6492*da0073e9SAndroid Build Coastguard Worker
6493*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
6494*da0073e9SAndroid Build Coastguard Worker    def test_type_conversions_same_device(self, devices):
6495*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, device=devices[1])
6496*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.int().device, torch.device(devices[1]))
6497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.type(torch.int).device, torch.device(devices[1]))
6498*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.to(torch.int).device, torch.device(devices[1]))
6499*da0073e9SAndroid Build Coastguard Worker
6500*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.double,
6501*da0073e9SAndroid Build Coastguard Worker                  torch.int8, torch.short, torch.int, torch.long,
6502*da0073e9SAndroid Build Coastguard Worker                  torch.uint8)
6503*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double,
6504*da0073e9SAndroid Build Coastguard Worker            torch.int8, torch.short, torch.int, torch.long,
6505*da0073e9SAndroid Build Coastguard Worker            torch.uint8)
6506*da0073e9SAndroid Build Coastguard Worker    def test_from_sequence(self, device, dtype):
6507*da0073e9SAndroid Build Coastguard Worker        seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
6508*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0, 20).resize_(5, 4)
6509*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference, exact_dtype=False)
6510*da0073e9SAndroid Build Coastguard Worker
6511*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to indexing test suite
6512*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
6513*da0073e9SAndroid Build Coastguard Worker    def test_advancedindex_mixed_cpu_devices(self, devices) -> None:
6514*da0073e9SAndroid Build Coastguard Worker        def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
6515*da0073e9SAndroid Build Coastguard Worker            # test getitem
6516*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[:, ia, None, ib, 0].cpu(),
6517*da0073e9SAndroid Build Coastguard Worker                             x.cpu()[:, ia.cpu(), None, ib.cpu(), 0])
6518*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[ia], x.cpu()[ia.cpu()])
6519*da0073e9SAndroid Build Coastguard Worker            # test setitem
6520*da0073e9SAndroid Build Coastguard Worker            x_clone1 = x.clone()
6521*da0073e9SAndroid Build Coastguard Worker            x_clone2 = x.clone()
6522*da0073e9SAndroid Build Coastguard Worker            first_shape = x[:, ia, None, ib, 0].shape
6523*da0073e9SAndroid Build Coastguard Worker            second_shape = x[ia].shape
6524*da0073e9SAndroid Build Coastguard Worker            x_clone1[:, ia, None, ib, 0] = torch.randn(first_shape).to(x_clone1)
6525*da0073e9SAndroid Build Coastguard Worker            x_clone2[ia] = torch.randn(second_shape).to(x_clone2)
6526*da0073e9SAndroid Build Coastguard Worker
6527*da0073e9SAndroid Build Coastguard Worker        cpu = torch.device('cpu')
6528*da0073e9SAndroid Build Coastguard Worker        for device in devices:
6529*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, 4, 4, 4, 3)
6530*da0073e9SAndroid Build Coastguard Worker            ia = torch.tensor([0, 2, 1])
6531*da0073e9SAndroid Build Coastguard Worker            ib = torch.tensor([0, 2, 1])
6532*da0073e9SAndroid Build Coastguard Worker
6533*da0073e9SAndroid Build Coastguard Worker            # Index device tensor with cpu tensor
6534*da0073e9SAndroid Build Coastguard Worker            x = x.to(device)
6535*da0073e9SAndroid Build Coastguard Worker            ia = ia.to(cpu)
6536*da0073e9SAndroid Build Coastguard Worker            ib = ib.to(cpu)
6537*da0073e9SAndroid Build Coastguard Worker            test(x, ia, ib)
6538*da0073e9SAndroid Build Coastguard Worker
6539*da0073e9SAndroid Build Coastguard Worker            # Index device tensor with mixed cpu, device tensors
6540*da0073e9SAndroid Build Coastguard Worker            x = x.to(device)
6541*da0073e9SAndroid Build Coastguard Worker            ia = ia.to(cpu)
6542*da0073e9SAndroid Build Coastguard Worker            ib = ib.to(device)
6543*da0073e9SAndroid Build Coastguard Worker            test(x, ia, ib)
6544*da0073e9SAndroid Build Coastguard Worker
6545*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
6546*da0073e9SAndroid Build Coastguard Worker    def test_advancedindex_mixed_devices_error(self, devices) -> None:
6547*da0073e9SAndroid Build Coastguard Worker        def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
6548*da0073e9SAndroid Build Coastguard Worker            # test getitem
6549*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
6550*da0073e9SAndroid Build Coastguard Worker                value = x[:, ia, None, ib, 0]
6551*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
6552*da0073e9SAndroid Build Coastguard Worker                value = x[ib]
6553*da0073e9SAndroid Build Coastguard Worker
6554*da0073e9SAndroid Build Coastguard Worker        cpu = torch.device('cpu')
6555*da0073e9SAndroid Build Coastguard Worker        for device in devices:
6556*da0073e9SAndroid Build Coastguard Worker            # Index cpu tensor with device tensor
6557*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, 4, 4, 4, 3)
6558*da0073e9SAndroid Build Coastguard Worker            ia = torch.tensor([0, 2, 1]).to(device)
6559*da0073e9SAndroid Build Coastguard Worker            ib = torch.tensor([0, 2, 1]).to(device)
6560*da0073e9SAndroid Build Coastguard Worker            test(x, ia, ib)
6561*da0073e9SAndroid Build Coastguard Worker
6562*da0073e9SAndroid Build Coastguard Worker            # Index cpu tensor with mixed cpu, device tensors
6563*da0073e9SAndroid Build Coastguard Worker            x = x.to(cpu)
6564*da0073e9SAndroid Build Coastguard Worker            ia = ia.to(cpu)
6565*da0073e9SAndroid Build Coastguard Worker            ib = ib.to(device)
6566*da0073e9SAndroid Build Coastguard Worker            test(x, ia, ib)
6567*da0073e9SAndroid Build Coastguard Worker
6568*da0073e9SAndroid Build Coastguard Worker            if len(devices) > 1:
6569*da0073e9SAndroid Build Coastguard Worker                other_device = devices[0] if device == devices[1] else devices[1]
6570*da0073e9SAndroid Build Coastguard Worker
6571*da0073e9SAndroid Build Coastguard Worker                # Index device tensor with mixed cpu, device tensors on different devices
6572*da0073e9SAndroid Build Coastguard Worker                x = x.to(device)
6573*da0073e9SAndroid Build Coastguard Worker                ia = ia.to(cpu)
6574*da0073e9SAndroid Build Coastguard Worker                ib = ib.to(other_device)
6575*da0073e9SAndroid Build Coastguard Worker                test(x, ia, ib)
6576*da0073e9SAndroid Build Coastguard Worker
6577*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to data movement test suite
6578*da0073e9SAndroid Build Coastguard Worker    def test_copy_broadcast(self, device) -> None:
6579*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 5)
6580*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, device=device)
6581*da0073e9SAndroid Build Coastguard Worker        x.copy_(y)
6582*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[3], y)
6583*da0073e9SAndroid Build Coastguard Worker
6584*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 5, device=device)
6585*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5)
6586*da0073e9SAndroid Build Coastguard Worker        x.copy_(y)
6587*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[3], y)
6588*da0073e9SAndroid Build Coastguard Worker
6589*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to an elementwise ternary test suite
6590*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int64, torch.float32, torch.float64)
6591*da0073e9SAndroid Build Coastguard Worker    def test_clamp(self, device, dtype):
6592*da0073e9SAndroid Build Coastguard Worker        test_args = [
6593*da0073e9SAndroid Build Coastguard Worker            *product(
6594*da0073e9SAndroid Build Coastguard Worker                [(100, 50), (10, 64), (97,)],  # shape
6595*da0073e9SAndroid Build Coastguard Worker                (True, False),  # non-contiguous
6596*da0073e9SAndroid Build Coastguard Worker            )
6597*da0073e9SAndroid Build Coastguard Worker        ]
6598*da0073e9SAndroid Build Coastguard Worker
6599*da0073e9SAndroid Build Coastguard Worker        for shape, noncontig in test_args:
6600*da0073e9SAndroid Build Coastguard Worker            x = make_tensor(shape, device=device, dtype=dtype,
6601*da0073e9SAndroid Build Coastguard Worker                            noncontiguous=noncontig)
6602*da0073e9SAndroid Build Coastguard Worker            ub = make_tensor(shape, device=device, dtype=dtype,
6603*da0073e9SAndroid Build Coastguard Worker                             noncontiguous=noncontig)
6604*da0073e9SAndroid Build Coastguard Worker            lb = make_tensor(shape, device=device, dtype=dtype,
6605*da0073e9SAndroid Build Coastguard Worker                             noncontiguous=noncontig)
6606*da0073e9SAndroid Build Coastguard Worker
6607*da0073e9SAndroid Build Coastguard Worker            expect = x.max(lb).min(ub)
6608*da0073e9SAndroid Build Coastguard Worker            actual = x.clamp(lb, ub)
6609*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
6610*da0073e9SAndroid Build Coastguard Worker
6611*da0073e9SAndroid Build Coastguard Worker            expect = np.clip(x.cpu().numpy(), lb.cpu().numpy(), ub.cpu().numpy())
6612*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
6613*da0073e9SAndroid Build Coastguard Worker
6614*da0073e9SAndroid Build Coastguard Worker            expect = x.max(lb)
6615*da0073e9SAndroid Build Coastguard Worker            actual = x.clamp(min=lb)
6616*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
6617*da0073e9SAndroid Build Coastguard Worker
6618*da0073e9SAndroid Build Coastguard Worker            expect = x.min(ub)
6619*da0073e9SAndroid Build Coastguard Worker            actual = x.clamp(max=ub)
6620*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
6621*da0073e9SAndroid Build Coastguard Worker
6622*da0073e9SAndroid Build Coastguard Worker            # Test broadcasting min & max
6623*da0073e9SAndroid Build Coastguard Worker            expect = x.max(lb[0]).min(ub[..., :1])
6624*da0073e9SAndroid Build Coastguard Worker            actual = x.clamp(lb[0], ub[..., :1])
6625*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
6626*da0073e9SAndroid Build Coastguard Worker
6627*da0073e9SAndroid Build Coastguard Worker            # Test broadcasting x
6628*da0073e9SAndroid Build Coastguard Worker            expect = x[..., :1].max(lb).min(ub)
6629*da0073e9SAndroid Build Coastguard Worker            actual = x[..., :1].clamp(lb, ub)
6630*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
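
    # clamp(min, max) is the elementwise max-then-min used as the reference above, e.g. with
    # assumed values:
    #   >>> torch.tensor([1., 5., 9.]).clamp(2., 7.)
    #   tensor([2., 5., 7.])
    # which matches torch.tensor([1., 5., 9.]).max(torch.tensor(2.)).min(torch.tensor(7.)).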
6631*da0073e9SAndroid Build Coastguard Worker
6632*da0073e9SAndroid Build Coastguard Worker    def test_cuda_device_idx(self, device):
6633*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(3, device=device)
6634*da0073e9SAndroid Build Coastguard Worker        y = torch._efficientzerotensor(3, device=device)
6635*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.device, y.device)
6636*da0073e9SAndroid Build Coastguard Worker
6637*da0073e9SAndroid Build Coastguard Worker# we implemented custom deallocation for subclasses, so it behooves
6638*da0073e9SAndroid Build Coastguard Worker# us to make sure all of these bits work.  We'll use __del__ to
6639*da0073e9SAndroid Build Coastguard Worker# track if objects die or not
6640*da0073e9SAndroid Build Coastguard Workerclass Tracker:
6641*da0073e9SAndroid Build Coastguard Worker    def __init__(self, marker):
6642*da0073e9SAndroid Build Coastguard Worker        self.marker = marker
6643*da0073e9SAndroid Build Coastguard Worker
6644*da0073e9SAndroid Build Coastguard Worker    @staticmethod
6645*da0073e9SAndroid Build Coastguard Worker    def make():
6646*da0073e9SAndroid Build Coastguard Worker        marker = [False]
6647*da0073e9SAndroid Build Coastguard Worker        return marker, Tracker(marker)
6648*da0073e9SAndroid Build Coastguard Worker
6649*da0073e9SAndroid Build Coastguard Worker    def __del__(self):
6650*da0073e9SAndroid Build Coastguard Worker        self.marker[0] = True
6651*da0073e9SAndroid Build Coastguard Worker
6652*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager
6653*da0073e9SAndroid Build Coastguard Workerdef disable_gc():
6654*da0073e9SAndroid Build Coastguard Worker    if gc.isenabled():
6655*da0073e9SAndroid Build Coastguard Worker        try:
6656*da0073e9SAndroid Build Coastguard Worker            gc.disable()
6657*da0073e9SAndroid Build Coastguard Worker            yield
6658*da0073e9SAndroid Build Coastguard Worker        finally:
6659*da0073e9SAndroid Build Coastguard Worker            gc.enable()
6660*da0073e9SAndroid Build Coastguard Worker    else:
6661*da0073e9SAndroid Build Coastguard Worker        yield
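
# A minimal sketch of how Tracker and disable_gc combine (illustrative helper, assumes CPython
# reference-counting semantics): with the cyclic collector paused, __del__ still runs as soon
# as the last reference dies, flipping the marker.
def _tracker_disable_gc_sketch():
    marker, tracked = Tracker.make()
    with disable_gc():
        del tracked  # refcount drops to zero; Tracker.__del__ sets marker[0] = True
    return marker[0]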
6662*da0073e9SAndroid Build Coastguard Worker
6663*da0073e9SAndroid Build Coastguard Workerclass TestTorch(TestCase):
6664*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
6665*da0073e9SAndroid Build Coastguard Worker
6666*da0073e9SAndroid Build Coastguard Worker    def test_dir(self):
6667*da0073e9SAndroid Build Coastguard Worker        dir(torch)
6668*da0073e9SAndroid Build Coastguard Worker
6669*da0073e9SAndroid Build Coastguard Worker    def test_wildcard_import(self):
6670*da0073e9SAndroid Build Coastguard Worker        exec('from torch import *')
6671*da0073e9SAndroid Build Coastguard Worker
6672*da0073e9SAndroid Build Coastguard Worker    def test_newaxis_numpy_comparison(self):
6673*da0073e9SAndroid Build Coastguard Worker        def run_test(tensor, *idx):
6674*da0073e9SAndroid Build Coastguard Worker            npt = tensor.numpy()
6675*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor[idx], npt[idx])
6676*da0073e9SAndroid Build Coastguard Worker
6677*da0073e9SAndroid Build Coastguard Worker        # 1D Tensor Tests
6678*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 10)
6679*da0073e9SAndroid Build Coastguard Worker        cases = [
6680*da0073e9SAndroid Build Coastguard Worker            [None],
6681*da0073e9SAndroid Build Coastguard Worker            [None, None],
6682*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, None],
6683*da0073e9SAndroid Build Coastguard Worker            [None, Ellipsis],
6684*da0073e9SAndroid Build Coastguard Worker            [2, None],
6685*da0073e9SAndroid Build Coastguard Worker            [None, 2],
6686*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, None, 2],
6687*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, 2, None],
6688*da0073e9SAndroid Build Coastguard Worker            [2, Ellipsis, None],
6689*da0073e9SAndroid Build Coastguard Worker            [2, None, Ellipsis],
6690*da0073e9SAndroid Build Coastguard Worker            [None, 2, Ellipsis],
6691*da0073e9SAndroid Build Coastguard Worker            [None, Ellipsis, 2],
6692*da0073e9SAndroid Build Coastguard Worker        ]
6693*da0073e9SAndroid Build Coastguard Worker
6694*da0073e9SAndroid Build Coastguard Worker        for case in cases:
6695*da0073e9SAndroid Build Coastguard Worker            run_test(x, *case)
6696*da0073e9SAndroid Build Coastguard Worker
6697*da0073e9SAndroid Build Coastguard Worker        # 2D Tensor Tests
6698*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12).view(3, 4)
6699*da0073e9SAndroid Build Coastguard Worker        cases = [
6700*da0073e9SAndroid Build Coastguard Worker            [None],
6701*da0073e9SAndroid Build Coastguard Worker            [None, None],
6702*da0073e9SAndroid Build Coastguard Worker            [None, None, None],
6703*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, None],
6704*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, None, None],
6705*da0073e9SAndroid Build Coastguard Worker            [None, Ellipsis],
6706*da0073e9SAndroid Build Coastguard Worker            [None, Ellipsis, None],
6707*da0073e9SAndroid Build Coastguard Worker            [None, None, Ellipsis],
6708*da0073e9SAndroid Build Coastguard Worker            [2, None],
6709*da0073e9SAndroid Build Coastguard Worker            [2, None, Ellipsis],
6710*da0073e9SAndroid Build Coastguard Worker            [2, Ellipsis, None],
6711*da0073e9SAndroid Build Coastguard Worker            [None, 2, Ellipsis],
6712*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, 2, None],
6713*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, None, 2],
6714*da0073e9SAndroid Build Coastguard Worker            [None, Ellipsis, 2],
6715*da0073e9SAndroid Build Coastguard Worker            [1, 2, None],
6716*da0073e9SAndroid Build Coastguard Worker            [1, 2, Ellipsis, None],
6717*da0073e9SAndroid Build Coastguard Worker            [1, Ellipsis, 2, None],
6718*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, 1, None, 2],
6719*da0073e9SAndroid Build Coastguard Worker            [Ellipsis, 1, 2, None],
6720*da0073e9SAndroid Build Coastguard Worker            [1, None, 2, Ellipsis],
6721*da0073e9SAndroid Build Coastguard Worker            [None, 1, Ellipsis, 2],
6722*da0073e9SAndroid Build Coastguard Worker            [None, 1, 2, Ellipsis],
6723*da0073e9SAndroid Build Coastguard Worker        ]
6724*da0073e9SAndroid Build Coastguard Worker
6725*da0073e9SAndroid Build Coastguard Worker        for case in cases:
6726*da0073e9SAndroid Build Coastguard Worker            run_test(x, *case)
6727*da0073e9SAndroid Build Coastguard Worker
6728*da0073e9SAndroid Build Coastguard Worker    def _consecutive(self, size, start=1):
6729*da0073e9SAndroid Build Coastguard Worker        sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0)
6730*da0073e9SAndroid Build Coastguard Worker        sequence.add_(start - 1)
6731*da0073e9SAndroid Build Coastguard Worker        return sequence.resize_(*size)
6732*da0073e9SAndroid Build Coastguard Worker
6733*da0073e9SAndroid Build Coastguard Worker    def test_newindex(self):
6734*da0073e9SAndroid Build Coastguard Worker        reference = self._consecutive((3, 3, 3))
6735*da0073e9SAndroid Build Coastguard Worker        # This relies on __index__() being correct - but we have separate tests for that
6736*da0073e9SAndroid Build Coastguard Worker
6737*da0073e9SAndroid Build Coastguard Worker        def checkPartialAssign(index):
6738*da0073e9SAndroid Build Coastguard Worker            reference = torch.zeros(3, 3, 3)
6739*da0073e9SAndroid Build Coastguard Worker            reference[index] = self._consecutive((3, 3, 3))[index]
6740*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], atol=0, rtol=0)
6741*da0073e9SAndroid Build Coastguard Worker            reference[index] = 0
6742*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(reference, torch.zeros(3, 3, 3), atol=0, rtol=0)
6743*da0073e9SAndroid Build Coastguard Worker
6744*da0073e9SAndroid Build Coastguard Worker        checkPartialAssign(0)
6745*da0073e9SAndroid Build Coastguard Worker        checkPartialAssign(1)
6746*da0073e9SAndroid Build Coastguard Worker        checkPartialAssign(2)
6747*da0073e9SAndroid Build Coastguard Worker        checkPartialAssign((0, 1))
6748*da0073e9SAndroid Build Coastguard Worker        checkPartialAssign((1, 2))
6749*da0073e9SAndroid Build Coastguard Worker        checkPartialAssign((0, 2))
6750*da0073e9SAndroid Build Coastguard Worker        checkPartialAssign(torch.LongTensor((0, 2)))
6751*da0073e9SAndroid Build Coastguard Worker
6752*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
6753*da0073e9SAndroid Build Coastguard Worker            reference[1, 1, 1, 1] = 1
6754*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
6755*da0073e9SAndroid Build Coastguard Worker            reference[1, 1, 1, (1, 1)] = 1
6756*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
6757*da0073e9SAndroid Build Coastguard Worker            reference[3, 3, 3, 3, 3, 3, 3, 3] = 1
6758*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
6759*da0073e9SAndroid Build Coastguard Worker            reference[0.0] = 1
6760*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
6761*da0073e9SAndroid Build Coastguard Worker            reference[0.0:2.0] = 1
6762*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
6763*da0073e9SAndroid Build Coastguard Worker            reference[0.0, 0.0:2.0] = 1
6764*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
6765*da0073e9SAndroid Build Coastguard Worker            reference[0.0, :, 0.0:2.0] = 1
6766*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
6767*da0073e9SAndroid Build Coastguard Worker            reference[0.0, ..., 0.0:2.0] = 1
6768*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
6769*da0073e9SAndroid Build Coastguard Worker            reference[0.0, :, 0.0] = 1
6770*da0073e9SAndroid Build Coastguard Worker
6771*da0073e9SAndroid Build Coastguard Worker    # Test `torch._check*` functions
6772*da0073e9SAndroid Build Coastguard Worker    def test_check(self):
6773*da0073e9SAndroid Build Coastguard Worker        test_cases = [
6774*da0073e9SAndroid Build Coastguard Worker            # check function, expected error
6775*da0073e9SAndroid Build Coastguard Worker            (torch._check, RuntimeError),
6776*da0073e9SAndroid Build Coastguard Worker            (torch._check_index, IndexError),
6777*da0073e9SAndroid Build Coastguard Worker            (torch._check_value, ValueError),
6778*da0073e9SAndroid Build Coastguard Worker            (torch._check_type, TypeError),
6779*da0073e9SAndroid Build Coastguard Worker            (torch._check_not_implemented, NotImplementedError),
6780*da0073e9SAndroid Build Coastguard Worker        ]
6781*da0073e9SAndroid Build Coastguard Worker
6782*da0073e9SAndroid Build Coastguard Worker        for check_fn, expected_error in test_cases:
6783*da0073e9SAndroid Build Coastguard Worker            # cond=True should not raise an error
6784*da0073e9SAndroid Build Coastguard Worker            check_fn(True)
6785*da0073e9SAndroid Build Coastguard Worker
6786*da0073e9SAndroid Build Coastguard Worker            # Test default failure message for cond=False
6787*da0073e9SAndroid Build Coastguard Worker            default_message = 'Expected cond to be True'
6788*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(expected_error, default_message):
6789*da0073e9SAndroid Build Coastguard Worker                check_fn(False)
6790*da0073e9SAndroid Build Coastguard Worker
6791*da0073e9SAndroid Build Coastguard Worker            # Test a simple failure message
6792*da0073e9SAndroid Build Coastguard Worker            message = 'message'
6793*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(expected_error, message):
6794*da0073e9SAndroid Build Coastguard Worker                check_fn(False, lambda: message)
6795*da0073e9SAndroid Build Coastguard Worker
6796*da0073e9SAndroid Build Coastguard Worker            # Test message with tensor
6797*da0073e9SAndroid Build Coastguard Worker            def message():
6798*da0073e9SAndroid Build Coastguard Worker                return torch.arange(4)
6799*da0073e9SAndroid Build Coastguard Worker
6800*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
6801*da0073e9SAndroid Build Coastguard Worker                check_fn(False, message)
6802*da0073e9SAndroid Build Coastguard Worker
6803*da0073e9SAndroid Build Coastguard Worker            # Test format string message
6804*da0073e9SAndroid Build Coastguard Worker            def message():
6805*da0073e9SAndroid Build Coastguard Worker                return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}"
6806*da0073e9SAndroid Build Coastguard Worker
6807*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
6808*da0073e9SAndroid Build Coastguard Worker                check_fn(False, message)
6809*da0073e9SAndroid Build Coastguard Worker
6810*da0073e9SAndroid Build Coastguard Worker            # Test incorrect `cond` arg type
6811*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, 'cond must be a bool'):
6812*da0073e9SAndroid Build Coastguard Worker                check_fn('wrong type')
6813*da0073e9SAndroid Build Coastguard Worker
6814*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, 'cond must be a bool'):
6815*da0073e9SAndroid Build Coastguard Worker                check_fn(torch.tensor(True))
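
    # Typical usage of these internal checks (sketch; ``x`` is an assumed tensor argument): the
    # message is passed as a callable so it only needs to be evaluated when the check fails.
    #   >>> torch._check(x.dim() == 2, lambda: f"expected a 2D tensor, got {x.dim()}D")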
6816*da0073e9SAndroid Build Coastguard Worker
6817*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to indexing test suite
6818*da0073e9SAndroid Build Coastguard Worker    def test_index_add(self):
6819*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
6820*da0073e9SAndroid Build Coastguard Worker            for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
6821*da0073e9SAndroid Build Coastguard Worker                for other_sizes in ((), (4, 5)):
6822*da0073e9SAndroid Build Coastguard Worker                    for dtype in [torch.int, torch.long]:
6823*da0073e9SAndroid Build Coastguard Worker                        num_copy, num_dest = 3, 3
6824*da0073e9SAndroid Build Coastguard Worker                        dest = torch.randn(num_dest, *other_sizes, device=device)
6825*da0073e9SAndroid Build Coastguard Worker                        if not dest_contig:
6826*da0073e9SAndroid Build Coastguard Worker                            dest = make_tensor(dest.shape, device=device, dtype=dest.dtype, noncontiguous=True)
6827*da0073e9SAndroid Build Coastguard Worker                        src = torch.randn(num_copy, *other_sizes, device=device)
6828*da0073e9SAndroid Build Coastguard Worker                        if not src_contig:
6829*da0073e9SAndroid Build Coastguard Worker                            src = noncontiguous_like(src)
6830*da0073e9SAndroid Build Coastguard Worker                        idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy)
6831*da0073e9SAndroid Build Coastguard Worker                        if not index_contig:
6832*da0073e9SAndroid Build Coastguard Worker                            idx = noncontiguous_like(idx)
6833*da0073e9SAndroid Build Coastguard Worker                        # index_add_ without alpha argument
6834*da0073e9SAndroid Build Coastguard Worker                        dest2 = dest.clone()
6835*da0073e9SAndroid Build Coastguard Worker                        dest.index_add_(0, idx, src)
6836*da0073e9SAndroid Build Coastguard Worker                        for i in range(idx.size(0)):
6837*da0073e9SAndroid Build Coastguard Worker                            dest2[idx[i]] += src[i]
6838*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(dest, dest2)
6839*da0073e9SAndroid Build Coastguard Worker                        # index_add_ with alpha argument
6840*da0073e9SAndroid Build Coastguard Worker                        dest2 = dest.clone()
6841*da0073e9SAndroid Build Coastguard Worker                        dest.index_add_(0, idx, src, alpha=2)
6842*da0073e9SAndroid Build Coastguard Worker                        for i in range(idx.size(0)):
6843*da0073e9SAndroid Build Coastguard Worker                            dest2[idx[i]] += src[i] * 2
6844*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(dest, dest2)
6845*da0073e9SAndroid Build Coastguard Worker
6846*da0073e9SAndroid Build Coastguard Worker    # FIXME: resolve comment below and move this to indexing test suite
6847*da0073e9SAndroid Build Coastguard Worker    # add coverage for issue with atomic add that appeared only for
6848*da0073e9SAndroid Build Coastguard Worker    # specific dtypes on cuda:
6849*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/29153
6850*da0073e9SAndroid Build Coastguard Worker    def test_index_add_all_dtypes(self):
6851*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
6852*da0073e9SAndroid Build Coastguard Worker            for dtype in get_all_math_dtypes(device):
6853*da0073e9SAndroid Build Coastguard Worker                for idx_dtype in [torch.int, torch.long]:
6854*da0073e9SAndroid Build Coastguard Worker                    size = [5, 5]
6855*da0073e9SAndroid Build Coastguard Worker                    if dtype.is_floating_point or dtype.is_complex:
6856*da0073e9SAndroid Build Coastguard Worker                        tensor = torch.rand(size, dtype=dtype, device=device)
6857*da0073e9SAndroid Build Coastguard Worker                    elif dtype.is_signed:
6858*da0073e9SAndroid Build Coastguard Worker                        tensor = torch.randint(-5, 15, size, dtype=dtype, device=device)
6859*da0073e9SAndroid Build Coastguard Worker                    else:
6860*da0073e9SAndroid Build Coastguard Worker                        tensor = torch.randint(0, 10, size, dtype=dtype, device=device)
6861*da0073e9SAndroid Build Coastguard Worker
6862*da0073e9SAndroid Build Coastguard Worker                    # index_add calls atomicAdd on cuda.
6863*da0073e9SAndroid Build Coastguard Worker                    zeros = torch.zeros(size, dtype=dtype, device=device)
6864*da0073e9SAndroid Build Coastguard Worker
6865*da0073e9SAndroid Build Coastguard Worker                    added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor)
6866*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(added, tensor)
6867*da0073e9SAndroid Build Coastguard Worker
6868*da0073e9SAndroid Build Coastguard Worker                    added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor, alpha=-1)
6869*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(added, -tensor)
6870*da0073e9SAndroid Build Coastguard Worker
6871*da0073e9SAndroid Build Coastguard Worker    @unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False)
6872*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6873*da0073e9SAndroid Build Coastguard Worker    def test_index_add_correctness(self):
6874*da0073e9SAndroid Build Coastguard Worker        # Check whether index_add can get correct result when
6875*da0073e9SAndroid Build Coastguard Worker        # alpha is 1, and dtype of index is torch.long,
6876*da0073e9SAndroid Build Coastguard Worker        # i.e., using scatter_add
6877*da0073e9SAndroid Build Coastguard Worker        def helper(dim, dtype, device, size_result, size_source):
6878*da0073e9SAndroid Build Coastguard Worker            tensor = torch.zeros(size_result, dtype=dtype, device=device)
6879*da0073e9SAndroid Build Coastguard Worker            index = torch.randint(0, size_result[dim], (size_source[dim],),
6880*da0073e9SAndroid Build Coastguard Worker                                  dtype=torch.long, device=device)
6881*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point or dtype.is_complex:
6882*da0073e9SAndroid Build Coastguard Worker                source = torch.rand(size_source, dtype=dtype, device=device)
6883*da0073e9SAndroid Build Coastguard Worker            elif dtype.is_signed:
6884*da0073e9SAndroid Build Coastguard Worker                source = torch.randint(-2, 5, size_source, dtype=dtype, device=device)
6885*da0073e9SAndroid Build Coastguard Worker            else:
6886*da0073e9SAndroid Build Coastguard Worker                source = torch.randint(0, 5, size_source, dtype=dtype, device=device)
6887*da0073e9SAndroid Build Coastguard Worker
6888*da0073e9SAndroid Build Coastguard Worker            ref_out = tensor.index_add(dim, index, source, alpha=2.) / 2.
6889*da0073e9SAndroid Build Coastguard Worker            ref_out = ref_out.to(dtype=dtype)
6890*da0073e9SAndroid Build Coastguard Worker            out = tensor.index_add(dim, index, source)
6891*da0073e9SAndroid Build Coastguard Worker            if device == 'cuda':
6892*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, ref_out, atol=1e-2, rtol=1e-2)
6893*da0073e9SAndroid Build Coastguard Worker            else:
6894*da0073e9SAndroid Build Coastguard Worker                # scatter_add uses fp32 as accumulate type, while index_add doesn't.
6895*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, ref_out.to(dtype=dtype), atol=1e-2, rtol=1e-2)
6896*da0073e9SAndroid Build Coastguard Worker
6897*da0073e9SAndroid Build Coastguard Worker        for dim in [-1, -2, -3]:
6898*da0073e9SAndroid Build Coastguard Worker            for dtype in all_types_and_complex_and(torch.half, torch.bfloat16):
6899*da0073e9SAndroid Build Coastguard Worker                for device in get_all_device_types():
6900*da0073e9SAndroid Build Coastguard Worker                    for size in [(2, 512, 256), (5, 256, 256)]:
6901*da0073e9SAndroid Build Coastguard Worker                        helper(dim, dtype, device, size, size)
6902*da0073e9SAndroid Build Coastguard Worker
6903*da0073e9SAndroid Build Coastguard Worker                # Check bound
6904*da0073e9SAndroid Build Coastguard Worker                result = torch.zeros(1, 512, 256, dtype=dtype)
6905*da0073e9SAndroid Build Coastguard Worker                source = torch.ones(1, 512, 256, dtype=dtype)
6906*da0073e9SAndroid Build Coastguard Worker                index = torch.ones(257).to(dtype=torch.long)
6907*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: result.index_add_(dim, index, source))
6908*da0073e9SAndroid Build Coastguard Worker                index = (torch.ones(256) * 257).to(dtype=torch.long)
6909*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: result.index_add_(dim, index, source))
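
    # The reference above is built by calling index_add with alpha=2 on a zero tensor and then
    # halving the result: mathematically identical to alpha=1, but presumably routed through a
    # different kernel than the scatter_add fast path, so it serves as an independent value to
    # compare against within the given tolerances.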
6910*da0073e9SAndroid Build Coastguard Worker
6911*da0073e9SAndroid Build Coastguard Worker    def test_index_add_cornercase(self):
6912*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
6913*da0073e9SAndroid Build Coastguard Worker            dest = torch.randn((), device=device)
6914*da0073e9SAndroid Build Coastguard Worker            index = torch.tensor([0], device=device)
6915*da0073e9SAndroid Build Coastguard Worker            source = torch.randn(1, 1, 1, device=device)
6916*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6917*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
6918*da0073e9SAndroid Build Coastguard Worker                r"source tensor shape must match self tensor shape, excluding the specified dimension",
6919*da0073e9SAndroid Build Coastguard Worker            ):
6920*da0073e9SAndroid Build Coastguard Worker                dest.index_add(0, index, source)
6921*da0073e9SAndroid Build Coastguard Worker
6922*da0073e9SAndroid Build Coastguard Worker    def test_linspace_logspace(self):
6923*da0073e9SAndroid Build Coastguard Worker        # Ensure the output does not require grad regardless of whether the inputs require grad.
6924*da0073e9SAndroid Build Coastguard Worker        # The output of factory functions should not be part of any computational graph.
6925*da0073e9SAndroid Build Coastguard Worker        start = 0.0
6926*da0073e9SAndroid Build Coastguard Worker        end = 3.0
6927*da0073e9SAndroid Build Coastguard Worker
6928*da0073e9SAndroid Build Coastguard Worker        for step in [0, 1, 2]:
6929*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
6930*da0073e9SAndroid Build Coastguard Worker                torch.linspace(
6931*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(start, requires_grad=True),
6932*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(end, requires_grad=True), step
6933*da0073e9SAndroid Build Coastguard Worker                ).requires_grad
6934*da0073e9SAndroid Build Coastguard Worker            )
6935*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.linspace(torch.tensor(start, requires_grad=True), end, step).requires_grad)
6936*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.linspace(start, torch.tensor(end, requires_grad=True), step).requires_grad)
6937*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
6938*da0073e9SAndroid Build Coastguard Worker                torch.logspace(
6939*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(start, requires_grad=True),
6940*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(end, requires_grad=True), step
6941*da0073e9SAndroid Build Coastguard Worker                ).requires_grad
6942*da0073e9SAndroid Build Coastguard Worker            )
6943*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.logspace(torch.tensor(start, requires_grad=True), end, step).requires_grad)
6944*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.logspace(start, torch.tensor(end, requires_grad=True), step).requires_grad)
6945*da0073e9SAndroid Build Coastguard Worker
6946*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to shape ops test suite
6947*da0073e9SAndroid Build Coastguard Worker    def test_unflatten(self):
6948*da0073e9SAndroid Build Coastguard Worker        # test args: tensor, int, sizes
6949*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([]).unflatten(0, (0, 1)), torch.empty(0, 1))
6950*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([1]).unflatten(0, (1, 1)), torch.tensor([[1]]))
6951*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, (2, 2)), torch.tensor([[1, 2], [3, 4]]))
6952*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, [2, 2]), torch.tensor([[1, 2], [3, 4]]))
6953*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, torch.Size([2, 2])), torch.tensor([[1, 2], [3, 4]]))
6954*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(2, 10).unflatten(1, (5, 2)), torch.ones(2, 5, 2))
6955*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, (-1, 2)),
6956*da0073e9SAndroid Build Coastguard Worker                         torch.tensor([[1, 2], [3, 4]]))
6957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(2, 10).unflatten(1, (5, -1)),
6958*da0073e9SAndroid Build Coastguard Worker                         torch.ones(2, 5, 2))
6959*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(2, 10).unflatten(1, (-1,)),
6960*da0073e9SAndroid Build Coastguard Worker                         torch.ones(2, 10))
6961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(2, 3 * 4 * 5 * 6).unflatten(1, (3, 4, -1, 6)),
6962*da0073e9SAndroid Build Coastguard Worker                         torch.ones(2, 3, 4, 5, 6))
6963*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(2, 0, 2).unflatten(1, (3, -1, 4, 5)),
6964*da0073e9SAndroid Build Coastguard Worker                         torch.ones(2, 3, 0, 4, 5, 2))
6965*da0073e9SAndroid Build Coastguard Worker
6966*da0073e9SAndroid Build Coastguard Worker        # test invalid args: tensor, str, sizes
6967*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
6968*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1]).unflatten('A', (1, 1))
6969*da0073e9SAndroid Build Coastguard Worker
6970*da0073e9SAndroid Build Coastguard Worker        # test invalid args: tensor, str, namedshape
6971*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Name 'A' not found in Tensor\[None\]."):
6972*da0073e9SAndroid Build Coastguard Worker            torch.ones(4).unflatten('A', (('A', 2), ('B', 2)))
6973*da0073e9SAndroid Build Coastguard Worker
6974*da0073e9SAndroid Build Coastguard Worker        # test other invalid arguments
6975*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"sizes must be non-empty"):
6976*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1]).unflatten(0, [])
6977*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"):
6978*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1]).unflatten(0, [2, 2])
6979*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, r"Dimension specified as 0 but tensor has no dimensions"):
6980*da0073e9SAndroid Build Coastguard Worker            torch.tensor(1).unflatten(0, [0])
6981*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"):
6982*da0073e9SAndroid Build Coastguard Worker            torch.randn(5, 10).unflatten(1, (-1, -1))
6983*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
6984*da0073e9SAndroid Build Coastguard Worker                                    r"Provided sizes \[-1, 4\] don't multiply up to the size of dim 1 \(10\)"):
6985*da0073e9SAndroid Build Coastguard Worker            torch.randn(5, 10).unflatten(1, (-1, 4))
6986*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
6987*da0073e9SAndroid Build Coastguard Worker                                    r"the unspecified dimension size -1 can be any value and is ambiguous"):
6988*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 0).unflatten(1, (2, -1, 0))
6989*da0073e9SAndroid Build Coastguard Worker
6990*da0073e9SAndroid Build Coastguard Worker    # Test that warnings generated from C++ are translated to the correct type
6991*da0073e9SAndroid Build Coastguard Worker    def test_warn_types(self):
6992*da0073e9SAndroid Build Coastguard Worker        test_cases = [
6993*da0073e9SAndroid Build Coastguard Worker            # function, warning type, message
6994*da0073e9SAndroid Build Coastguard Worker            (torch._C._warn, UserWarning, r"Test message for TORCH_WARN"),
6995*da0073e9SAndroid Build Coastguard Worker            (torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"),
6996*da0073e9SAndroid Build Coastguard Worker        ]
6997*da0073e9SAndroid Build Coastguard Worker
6998*da0073e9SAndroid Build Coastguard Worker        for fn, warning_type, message in test_cases:
6999*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
7000*da0073e9SAndroid Build Coastguard Worker                warnings.resetwarnings()
7001*da0073e9SAndroid Build Coastguard Worker                warnings.filterwarnings('always', category=warning_type)
7002*da0073e9SAndroid Build Coastguard Worker                fn()
7003*da0073e9SAndroid Build Coastguard Worker
7004*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 1, msg=f'{warning_type} not raised')
7005*da0073e9SAndroid Build Coastguard Worker                warning = w[0].message
7006*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised')
7007*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(re.search(
7008*da0073e9SAndroid Build Coastguard Worker                    message,
7009*da0073e9SAndroid Build Coastguard Worker                    str(warning)))
7010*da0073e9SAndroid Build Coastguard Worker
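    # Check that structured return types (e.g. torch.return_types.max) have a readable,
    # stable repr; the expected string below is compared after dedenting.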
7011*da0073e9SAndroid Build Coastguard Worker    def test_structseq_repr(self):
7012*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(250).reshape(5, 5, 10)
7013*da0073e9SAndroid Build Coastguard Worker        expected = """
7014*da0073e9SAndroid Build Coastguard Worker        torch.return_types.max(
7015*da0073e9SAndroid Build Coastguard Worker        values=tensor([[ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49],
7016*da0073e9SAndroid Build Coastguard Worker                [ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
7017*da0073e9SAndroid Build Coastguard Worker                [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
7018*da0073e9SAndroid Build Coastguard Worker                [190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
7019*da0073e9SAndroid Build Coastguard Worker                [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]),
7020*da0073e9SAndroid Build Coastguard Worker        indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
7021*da0073e9SAndroid Build Coastguard Worker                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
7022*da0073e9SAndroid Build Coastguard Worker                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
7023*da0073e9SAndroid Build Coastguard Worker                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
7024*da0073e9SAndroid Build Coastguard Worker                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))"""
7025*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip())
7026*da0073e9SAndroid Build Coastguard Worker
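    # is_same_size compares shapes for regular tensors; for nested tensors both operands
    # must be nested, and sizes match only when every component tensor's size matches.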
7027*da0073e9SAndroid Build Coastguard Worker    def test_is_same_size(self):
7028*da0073e9SAndroid Build Coastguard Worker        t1 = torch.empty(3, 4, 9, 10)
7029*da0073e9SAndroid Build Coastguard Worker        t2 = torch.empty(3, 4)
7030*da0073e9SAndroid Build Coastguard Worker        t3 = torch.empty(1, 9, 3, 3)
7031*da0073e9SAndroid Build Coastguard Worker        t4 = torch.empty(3, 4, 9, 10)
7032*da0073e9SAndroid Build Coastguard Worker
7033*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t1.is_same_size(t2))
7034*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t1.is_same_size(t3))
7035*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t1.is_same_size(t4))
7036*da0073e9SAndroid Build Coastguard Worker
7037*da0073e9SAndroid Build Coastguard Worker        nt1 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(3, 4), torch.ones(5, 4)])
7038*da0073e9SAndroid Build Coastguard Worker        nt2 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(2, 4), torch.ones(2, 4)])
7039*da0073e9SAndroid Build Coastguard Worker        nt3 = torch.nested.nested_tensor([torch.ones(2, 4, 5), torch.ones(2, 6, 5)])
7040*da0073e9SAndroid Build Coastguard Worker        nt4 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(3, 4), torch.ones(5, 4)])
7041*da0073e9SAndroid Build Coastguard Worker
7042*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(nt1.is_same_size(nt2))
7043*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(nt1.is_same_size(nt3))
7044*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nt1.is_same_size(nt4))
7045*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected both self and other to be nested tensors."):
7046*da0073e9SAndroid Build Coastguard Worker            t1.is_same_size(nt1)
7047*da0073e9SAndroid Build Coastguard Worker
7048*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected both self and other to be nested tensors."):
7049*da0073e9SAndroid Build Coastguard Worker            nt1.is_same_size(t1)
7050*da0073e9SAndroid Build Coastguard Worker
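    # Tensor.set_ should rebind storage and metadata in place, whether the source is a
    # tensor or a storage and whether size/stride are passed positionally or by keyword.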
7051*da0073e9SAndroid Build Coastguard Worker    def test_tensor_set(self):
7052*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([])
7053*da0073e9SAndroid Build Coastguard Worker        t2 = torch.empty(3, 4, 9, 10).uniform_()
7054*da0073e9SAndroid Build Coastguard Worker        t1.set_(t2)
7055*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
7056*da0073e9SAndroid Build Coastguard Worker        size = torch.Size([9, 3, 4, 10])
7057*da0073e9SAndroid Build Coastguard Worker        t1.set_(t2.storage(), 0, size)
7058*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.size(), size)
7059*da0073e9SAndroid Build Coastguard Worker        t1.set_(t2.storage(), 0, tuple(size))
7060*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.size(), size)
7061*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.stride(), (120, 40, 10, 1))
7062*da0073e9SAndroid Build Coastguard Worker        stride = (10, 360, 90, 1)
7063*da0073e9SAndroid Build Coastguard Worker        t1.set_(t2.storage(), 0, size, stride)
7064*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.stride(), stride)
7065*da0073e9SAndroid Build Coastguard Worker        t1.set_(t2.storage(), 0, size=size, stride=stride)
7066*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.size(), size)
7067*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.stride(), stride)
7068*da0073e9SAndroid Build Coastguard Worker
7069*da0073e9SAndroid Build Coastguard Worker        # test argument names
7070*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([])
7071*da0073e9SAndroid Build Coastguard Worker        # 1. case when source is tensor
7072*da0073e9SAndroid Build Coastguard Worker        t1.set_(source=t2)
7073*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
7074*da0073e9SAndroid Build Coastguard Worker        # 2. case when source is storage
7075*da0073e9SAndroid Build Coastguard Worker        t1.set_(source=t2.storage())
7076*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
7077*da0073e9SAndroid Build Coastguard Worker        # 3. case when source is storage, and other args also specified
7078*da0073e9SAndroid Build Coastguard Worker        t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride)
7079*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.size(), size)
7080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.stride(), stride)
7081*da0073e9SAndroid Build Coastguard Worker
7082*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([True, True], dtype=torch.bool)
7083*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([False, False], dtype=torch.bool)
7084*da0073e9SAndroid Build Coastguard Worker        t1.set_(t2)
7085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
7086*da0073e9SAndroid Build Coastguard Worker
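    # set_ may not change a tensor's dtype or device; mismatched sources raise RuntimeError.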
7087*da0073e9SAndroid Build Coastguard Worker    def test_tensor_set_errors(self):
7088*da0073e9SAndroid Build Coastguard Worker        f_cpu = torch.randn((2, 3), dtype=torch.float32)
7089*da0073e9SAndroid Build Coastguard Worker        d_cpu = torch.randn((2, 3), dtype=torch.float64)
7090*da0073e9SAndroid Build Coastguard Worker
7091*da0073e9SAndroid Build Coastguard Worker        # change dtype
7092*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
7093*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError,
7094*da0073e9SAndroid Build Coastguard Worker                          lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride()))
7095*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu))
7096*da0073e9SAndroid Build Coastguard Worker
7097*da0073e9SAndroid Build Coastguard Worker        # change device
7098*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
7099*da0073e9SAndroid Build Coastguard Worker            f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda')
7100*da0073e9SAndroid Build Coastguard Worker
7101*da0073e9SAndroid Build Coastguard Worker            # cpu -> cuda
7102*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage()))
7103*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError,
7104*da0073e9SAndroid Build Coastguard Worker                              lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride()))
7105*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda))
7106*da0073e9SAndroid Build Coastguard Worker
7107*da0073e9SAndroid Build Coastguard Worker            # cuda -> cpu
7108*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage()))
7109*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError,
7110*da0073e9SAndroid Build Coastguard Worker                              lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride()))
7111*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu))
7112*da0073e9SAndroid Build Coastguard Worker
7113*da0073e9SAndroid Build Coastguard Worker    # FIXME: move this test to test_testing.py (along with allclose testing)
7114*da0073e9SAndroid Build Coastguard Worker    # NOTE: test_equal will be deprecated in favor of torch.testing.assert_close
7115*da0073e9SAndroid Build Coastguard Worker    #   once torch.testing is out of beta
7116*da0073e9SAndroid Build Coastguard Worker    def test_equal(self):
7118*da0073e9SAndroid Build Coastguard Worker        for device in ["cpu", "cuda"]:
7119*da0073e9SAndroid Build Coastguard Worker            if device == "cuda" and not torch.cuda.is_available():
7120*da0073e9SAndroid Build Coastguard Worker                continue
7121*da0073e9SAndroid Build Coastguard Worker
7122*da0073e9SAndroid Build Coastguard Worker            # Contiguous, 1D
7123*da0073e9SAndroid Build Coastguard Worker            t1 = torch.tensor((3., 4., 9., 10.), device=device)
7124*da0073e9SAndroid Build Coastguard Worker            t2 = t1.contiguous()
7125*da0073e9SAndroid Build Coastguard Worker            t3 = torch.tensor((1., 9., 3., 10.), device=device)
7126*da0073e9SAndroid Build Coastguard Worker            t4 = torch.tensor((3., 4., 9.), device=device)
7127*da0073e9SAndroid Build Coastguard Worker            t5 = torch.tensor([], device=device)
7128*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(t1.equal(t2))
7129*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(t1.equal(t3))
7130*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(t1.equal(t4))
7131*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(t1.equal(t5))
7132*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(t1, t2))
7133*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(t1, t3))
7134*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(t1, t4))
7135*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(t1, t5))
7136*da0073e9SAndroid Build Coastguard Worker
7137*da0073e9SAndroid Build Coastguard Worker            # Non contiguous, 2D
7138*da0073e9SAndroid Build Coastguard Worker            s = torch.tensor(((1, 2, 3, 4), (5, 6, 7, 8)), device=device)
7139*da0073e9SAndroid Build Coastguard Worker            s1 = s[:, 1:3]
7140*da0073e9SAndroid Build Coastguard Worker            s2 = s1.clone()
7141*da0073e9SAndroid Build Coastguard Worker            s3 = torch.tensor(((2, 3), (6, 7)), device=device)
7142*da0073e9SAndroid Build Coastguard Worker            s4 = torch.tensor(((0, 0), (0, 0)), device=device)
7143*da0073e9SAndroid Build Coastguard Worker
7144*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(s1.is_contiguous())
7145*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(s1.equal(s2))
7146*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(s1.equal(s3))
7147*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(s1.equal(s4))
7148*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(s1, s2))
7149*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(s1, s3))
7150*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(s1, s4))
7151*da0073e9SAndroid Build Coastguard Worker
7152*da0073e9SAndroid Build Coastguard Worker            # Different dtypes
7153*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor((1, 2, 3), dtype=torch.float, device=device)
7154*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor((1, 2, 3), dtype=torch.int, device=device)
7155*da0073e9SAndroid Build Coastguard Worker            z = torch.tensor((1, -1), dtype=torch.int, device=device)
7156*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(x, y))
7157*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(z, x))
7158*da0073e9SAndroid Build Coastguard Worker
7159*da0073e9SAndroid Build Coastguard Worker            # Fast path test: tensor flags, like neg and conj
7160*da0073e9SAndroid Build Coastguard Worker            neg_0 = torch.tensor((1, 2, 3), dtype=torch.float, device=device)
7161*da0073e9SAndroid Build Coastguard Worker            neg_1 = neg_0._neg_view()
7162*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(neg_1.is_neg())
7163*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(neg_0.data_ptr(), neg_1.data_ptr())
7164*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(neg_0.storage_offset(), neg_1.storage_offset())
7165*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(neg_0.stride(), neg_1.stride())
7166*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(neg_0.size(), neg_1.size())
7167*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(neg_0, neg_1))
7168*da0073e9SAndroid Build Coastguard Worker            # FIXME: Disable the following check due to the inductor failure
7169*da0073e9SAndroid Build Coastguard Worker            # See https://github.com/pytorch/pytorch/issues/100340 and
7170*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/98175
7171*da0073e9SAndroid Build Coastguard Worker            if not TEST_WITH_TORCHINDUCTOR:
7172*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.equal(neg_0, neg_1._neg_view()))
7173*da0073e9SAndroid Build Coastguard Worker
7174*da0073e9SAndroid Build Coastguard Worker            conj_0 = torch.tensor([1.0 + 2.0j, 2.0 + 1.0j], device=device)
7175*da0073e9SAndroid Build Coastguard Worker            conj_1 = conj_0.conj()
7176*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(conj_1.is_conj())
7177*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(conj_0.data_ptr(), conj_1.data_ptr())
7178*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(conj_0.storage_offset(), conj_1.storage_offset())
7179*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(conj_0.stride(), conj_1.stride())
7180*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(conj_0.size(), conj_1.size())
7181*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(conj_0, conj_1))
7182*da0073e9SAndroid Build Coastguard Worker            # FIXME: Disable the following check due to the inductor failure
7183*da0073e9SAndroid Build Coastguard Worker            # See https://github.com/pytorch/pytorch/issues/100340 and
7184*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/98175
7185*da0073e9SAndroid Build Coastguard Worker            if not TEST_WITH_TORCHINDUCTOR:
7186*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.equal(conj_0, conj_1.conj()))
7187*da0073e9SAndroid Build Coastguard Worker
7188*da0073e9SAndroid Build Coastguard Worker            # Fast path test: two tensors share the same storage but have different dtypes
7189*da0073e9SAndroid Build Coastguard Worker            s_0 = torch.rand((2, 3), dtype=torch.float, device=device)
7190*da0073e9SAndroid Build Coastguard Worker            s_1 = s_0.view(dtype=torch.int32)
7191*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s_0.data_ptr(), s_1.data_ptr())
7192*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s_0.storage_offset(), s_1.storage_offset())
7193*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s_0.stride(), s_1.stride())
7194*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s_0.size(), s_1.size())
7195*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(s_0, s_1))
7196*da0073e9SAndroid Build Coastguard Worker
7197*da0073e9SAndroid Build Coastguard Worker            # Fast path test: two tensors share the same storage but have different strides
7198*da0073e9SAndroid Build Coastguard Worker            t_0 = torch.rand((2, 3), dtype=torch.float, device=device)
7199*da0073e9SAndroid Build Coastguard Worker            t_1 = t_0.t()
7200*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t_0.data_ptr(), t_1.data_ptr())
7201*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t_0.storage_offset(), t_1.storage_offset())
7202*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(t_0.stride(), t_1.stride())
7203*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(t_0.size(), t_1.size())
7204*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(t_0, t_1))
7205*da0073e9SAndroid Build Coastguard Worker
7206*da0073e9SAndroid Build Coastguard Worker            # Fast path: tensor containing `nan` is not equal to self
7207*da0073e9SAndroid Build Coastguard Worker            for dtype in floating_and_complex_types():
7208*da0073e9SAndroid Build Coastguard Worker                t = torch.tensor([1., float('nan')], dtype=dtype)
7209*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.equal(t, t))
7210*da0073e9SAndroid Build Coastguard Worker
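    # element_size()/itemsize must agree between storages and tensors and satisfy the
    # portable size relations asserted at the end of the test.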
7211*da0073e9SAndroid Build Coastguard Worker    def test_element_size(self):
7212*da0073e9SAndroid Build Coastguard Worker        byte = torch.ByteStorage().element_size()
7213*da0073e9SAndroid Build Coastguard Worker        char = torch.CharStorage().element_size()
7214*da0073e9SAndroid Build Coastguard Worker        short = torch.ShortStorage().element_size()
7215*da0073e9SAndroid Build Coastguard Worker        int = torch.IntStorage().element_size()
7216*da0073e9SAndroid Build Coastguard Worker        long = torch.LongStorage().element_size()
7217*da0073e9SAndroid Build Coastguard Worker        float = torch.FloatStorage().element_size()
7218*da0073e9SAndroid Build Coastguard Worker        double = torch.DoubleStorage().element_size()
7219*da0073e9SAndroid Build Coastguard Worker        bool = torch.BoolStorage().element_size()
7220*da0073e9SAndroid Build Coastguard Worker        bfloat16 = torch.BFloat16Storage().element_size()
7221*da0073e9SAndroid Build Coastguard Worker        complexfloat = torch.ComplexFloatStorage().element_size()
7222*da0073e9SAndroid Build Coastguard Worker        complexdouble = torch.ComplexDoubleStorage().element_size()
7223*da0073e9SAndroid Build Coastguard Worker
7224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byte, torch.ByteTensor().element_size())
7225*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byte, torch.ByteTensor().itemsize)
7226*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(char, torch.CharTensor().element_size())
7227*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(char, torch.CharTensor().itemsize)
7228*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(short, torch.ShortTensor().element_size())
7229*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(short, torch.ShortTensor().itemsize)
7230*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(int, torch.IntTensor().element_size())
7231*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(int, torch.IntTensor().itemsize)
7232*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(long, torch.LongTensor().element_size())
7233*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(long, torch.LongTensor().itemsize)
7234*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(float, torch.FloatTensor().element_size())
7235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(float, torch.FloatTensor().itemsize)
7236*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double, torch.DoubleTensor().element_size())
7237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double, torch.DoubleTensor().itemsize)
7238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bool, torch.BoolTensor().element_size())
7239*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bool, torch.BoolTensor().itemsize)
7240*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).element_size())
7241*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).itemsize)
7242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).element_size())
7243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).itemsize)
7244*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).element_size())
7245*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).itemsize)
7246*da0073e9SAndroid Build Coastguard Worker
7247*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(byte, 0)
7248*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(char, 0)
7249*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(short, 0)
7250*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(int, 0)
7251*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(long, 0)
7252*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(float, 0)
7253*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(double, 0)
7254*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(bool, 0)
7255*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(bfloat16, 0)
7256*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(complexfloat, 0)
7257*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(complexdouble, 0)
7258*da0073e9SAndroid Build Coastguard Worker
7259*da0073e9SAndroid Build Coastguard Worker        # These checks hold on any platform; they are not the tightest possible bounds for a given system.
7260*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byte, 1)
7261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(char, 1)
7262*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bool, 1)
7263*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(short, 2)
7264*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(int, 2)
7265*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(int, short)
7266*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(long, 4)
7267*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(long, int)
7268*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(double, float)
7269*da0073e9SAndroid Build Coastguard Worker
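    # permute reorders a tensor's dimensions according to the given permutation and
    # leaves the original tensor's size unchanged.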
7270*da0073e9SAndroid Build Coastguard Worker    def test_permute(self):
7271*da0073e9SAndroid Build Coastguard Worker        orig = [1, 2, 3, 4, 5, 6, 7]
7272*da0073e9SAndroid Build Coastguard Worker        perm = torch.randperm(7).tolist()
7273*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(*orig).fill_(0)
7274*da0073e9SAndroid Build Coastguard Worker        new = [i - 1 for i in x.permute(*perm).size()]
7275*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(perm, new)
7276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.size(), orig)
7277*da0073e9SAndroid Build Coastguard Worker
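    # reversed() flips a tensor along its first dimension; a 0-dim tensor is returned as-is.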
7278*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
7279*da0073e9SAndroid Build Coastguard Worker    def test_reversed(self):
7280*da0073e9SAndroid Build Coastguard Worker        val = torch.arange(0, 10)
7281*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reversed(val), torch.arange(9, -1, -1))
7282*da0073e9SAndroid Build Coastguard Worker
7283*da0073e9SAndroid Build Coastguard Worker        val = torch.arange(1, 10).view(3, 3)
7284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reversed(val), torch.tensor([[7, 8, 9], [4, 5, 6], [1, 2, 3]]))
7285*da0073e9SAndroid Build Coastguard Worker
7286*da0073e9SAndroid Build Coastguard Worker        val = torch.tensor(42)
7287*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reversed(val), torch.tensor(42))
7288*da0073e9SAndroid Build Coastguard Worker
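    # `in` on a tensor supports scalar and tensor operands; other types raise a RuntimeError.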
7289*da0073e9SAndroid Build Coastguard Worker    def test_contains(self):
7290*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 10)
7291*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(4 in x, True)
7292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(12 in x, False)
7293*da0073e9SAndroid Build Coastguard Worker
7294*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1, 10).view(3, 3)
7295*da0073e9SAndroid Build Coastguard Worker        val = torch.arange(1, 4)
7296*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(val in x, True)
7297*da0073e9SAndroid Build Coastguard Worker        val += 10
7298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(val in x, False)
7299*da0073e9SAndroid Build Coastguard Worker
7300*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
7301*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
7302*da0073e9SAndroid Build Coastguard Worker            f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {str}.",
7303*da0073e9SAndroid Build Coastguard Worker            lambda: "foo" in x)
7304*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
7305*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
7306*da0073e9SAndroid Build Coastguard Worker            f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type([1, 2])}.",
7307*da0073e9SAndroid Build Coastguard Worker            lambda: [1, 2] in x)
7308*da0073e9SAndroid Build Coastguard Worker
7309*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
7310*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_parameter(self):
7311*da0073e9SAndroid Build Coastguard Worker        from copy import deepcopy
7312*da0073e9SAndroid Build Coastguard Worker        l = torch.nn.Linear(10, 1)
7313*da0073e9SAndroid Build Coastguard Worker        s = l.state_dict(keep_vars=True)
7314*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nn.Parameter, type(s['weight']))
7315*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nn.Parameter, type(s['bias']))
7316*da0073e9SAndroid Build Coastguard Worker
7317*da0073e9SAndroid Build Coastguard Worker        s2 = deepcopy(s)
7318*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nn.Parameter, type(s2['weight']))
7319*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nn.Parameter, type(s2['bias']))
7320*da0073e9SAndroid Build Coastguard Worker
7321*da0073e9SAndroid Build Coastguard Worker    def test_pickle(self):
7322*da0073e9SAndroid Build Coastguard Worker        import pickle
7323*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 5)
7324*da0073e9SAndroid Build Coastguard Worker        serialized = pickle.dumps(a)
7325*da0073e9SAndroid Build Coastguard Worker        b = pickle.loads(serialized)
7326*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, b)
7327*da0073e9SAndroid Build Coastguard Worker
7328*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
7329*da0073e9SAndroid Build Coastguard Worker    def test_pickle_parameter(self):
7330*da0073e9SAndroid Build Coastguard Worker        import pickle
7331*da0073e9SAndroid Build Coastguard Worker        a = torch.nn.Parameter(torch.randn(5, 5))
7332*da0073e9SAndroid Build Coastguard Worker        serialized = pickle.dumps(a)
7333*da0073e9SAndroid Build Coastguard Worker        b = pickle.loads(serialized)
7334*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(b, torch.nn.Parameter))
7335*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.requires_grad, b.requires_grad)
7336*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, b)
7337*da0073e9SAndroid Build Coastguard Worker
7338*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
7339*da0073e9SAndroid Build Coastguard Worker    def test_pickle_parameter_no_requires_grad(self):
7340*da0073e9SAndroid Build Coastguard Worker        import pickle
7341*da0073e9SAndroid Build Coastguard Worker        a = torch.nn.Parameter(torch.randn(5, 5), requires_grad=False)
7342*da0073e9SAndroid Build Coastguard Worker        serialized = pickle.dumps(a)
7343*da0073e9SAndroid Build Coastguard Worker        b = pickle.loads(serialized)
7344*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(b, torch.nn.Parameter))
7345*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.requires_grad, b.requires_grad)
7346*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, b)
7347*da0073e9SAndroid Build Coastguard Worker
7348*da0073e9SAndroid Build Coastguard Worker    def test_pickle_dtype(self):
7349*da0073e9SAndroid Build Coastguard Worker        t = torch.float32
7350*da0073e9SAndroid Build Coastguard Worker        serialized = pickle.dumps(t)
7351*da0073e9SAndroid Build Coastguard Worker        b = pickle.loads(serialized)
7352*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(b, torch.dtype))
7353*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(id(b), id(t))
7354*da0073e9SAndroid Build Coastguard Worker
7355*da0073e9SAndroid Build Coastguard Worker    def test_pickle_size(self):
7356*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10).size()
7357*da0073e9SAndroid Build Coastguard Worker        serialized = pickle.dumps(a)
7358*da0073e9SAndroid Build Coastguard Worker        b = pickle.loads(serialized)
7359*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(b, torch.Size))
7360*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, b)
7361*da0073e9SAndroid Build Coastguard Worker
7362*da0073e9SAndroid Build Coastguard Worker    def test_pickle_function(self):
7363*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/37703
7364*da0073e9SAndroid Build Coastguard Worker        a = torch.tanh
7365*da0073e9SAndroid Build Coastguard Worker        serialized = pickle.dumps(a)
7366*da0073e9SAndroid Build Coastguard Worker        b = pickle.loads(serialized)
7367*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, b)
7368*da0073e9SAndroid Build Coastguard Worker
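    # Exercises the CPU Generator API (manual_seed, seed, initial_seed, get_state/set_state)
    # and checks that copying a generator's state reproduces the same random stream.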
7369*da0073e9SAndroid Build Coastguard Worker    def test_generator_cpu(self):
7370*da0073e9SAndroid Build Coastguard Worker        # test default generators are equal
7371*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.default_generator, torch.default_generator)
7372*da0073e9SAndroid Build Coastguard Worker
7373*da0073e9SAndroid Build Coastguard Worker        # tests Generator API
7374*da0073e9SAndroid Build Coastguard Worker        # manual_seed, seed, initial_seed, get_state, set_state
7375*da0073e9SAndroid Build Coastguard Worker        g1 = torch.Generator()
7376*da0073e9SAndroid Build Coastguard Worker        g2 = torch.Generator()
7377*da0073e9SAndroid Build Coastguard Worker        g1.manual_seed(12345)
7378*da0073e9SAndroid Build Coastguard Worker        g2.manual_seed(12345)
7379*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g1.initial_seed(), g2.initial_seed())
7380*da0073e9SAndroid Build Coastguard Worker
7381*da0073e9SAndroid Build Coastguard Worker        g1.seed()
7382*da0073e9SAndroid Build Coastguard Worker        g2.seed()
7383*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(g1.initial_seed(), g2.initial_seed())
7384*da0073e9SAndroid Build Coastguard Worker
7385*da0073e9SAndroid Build Coastguard Worker        g1 = torch.Generator()
7386*da0073e9SAndroid Build Coastguard Worker        g2_state = g2.get_state()
7387*da0073e9SAndroid Build Coastguard Worker        g2_randn = torch.randn(1, generator=g2)
7388*da0073e9SAndroid Build Coastguard Worker        g1.set_state(g2_state)
7389*da0073e9SAndroid Build Coastguard Worker        g1_randn = torch.randn(1, generator=g1)
7390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g1_randn, g2_randn)
7391*da0073e9SAndroid Build Coastguard Worker
7392*da0073e9SAndroid Build Coastguard Worker        default_state = torch.default_generator.get_state()
7393*da0073e9SAndroid Build Coastguard Worker        q = torch.empty(100)
7394*da0073e9SAndroid Build Coastguard Worker        g1_normal = q.normal_()
7395*da0073e9SAndroid Build Coastguard Worker        g2 = torch.Generator()
7396*da0073e9SAndroid Build Coastguard Worker        g2.set_state(default_state)
7397*da0073e9SAndroid Build Coastguard Worker        g2_normal = q.normal_(generator=g2)
7398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g1_normal, g2_normal)
7399*da0073e9SAndroid Build Coastguard Worker
7400*da0073e9SAndroid Build Coastguard Worker    def test_invalid_generator_raises(self):
7401*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.Generator('opengl'))
7402*da0073e9SAndroid Build Coastguard Worker
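    # Pickling a Generator should round-trip its device, seed, offset (non-CPU only), and RNG state.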
7403*da0073e9SAndroid Build Coastguard Worker    def test_pickle_generator(self) -> None:
7404*da0073e9SAndroid Build Coastguard Worker        devices = ['cpu']
7405*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
7406*da0073e9SAndroid Build Coastguard Worker            devices += ['cuda']
7407*da0073e9SAndroid Build Coastguard Worker
7408*da0073e9SAndroid Build Coastguard Worker        for device in devices:
7409*da0073e9SAndroid Build Coastguard Worker            with self.subTest(device=device):
7410*da0073e9SAndroid Build Coastguard Worker                generator = torch.Generator(device=device).manual_seed(12345)
7411*da0073e9SAndroid Build Coastguard Worker                if device != "cpu":
7412*da0073e9SAndroid Build Coastguard Worker                    generator.set_offset(100)
7413*da0073e9SAndroid Build Coastguard Worker                torch.randn((100, 100), generator=generator, device=device)  # advance the RNG state
7414*da0073e9SAndroid Build Coastguard Worker
7415*da0073e9SAndroid Build Coastguard Worker                reserialized: torch.Generator = pickle.loads(pickle.dumps(generator))
7416*da0073e9SAndroid Build Coastguard Worker
7417*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(generator.device, reserialized.device)
7418*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(generator.initial_seed(), reserialized.initial_seed())
7419*da0073e9SAndroid Build Coastguard Worker                if device != "cpu":
7420*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(generator.get_offset(), reserialized.get_offset())
7421*da0073e9SAndroid Build Coastguard Worker                torch.testing.assert_close(generator.get_state(), reserialized.get_state())
7422*da0073e9SAndroid Build Coastguard Worker
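    # Reference values for the first eight 2-D Sobol points, with and without scrambling,
    # shared by the SobolEngine tests below.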
7423*da0073e9SAndroid Build Coastguard Worker    def _sobol_reference_samples(self, scramble: bool) -> torch.Tensor:
7424*da0073e9SAndroid Build Coastguard Worker        if not scramble:
7425*da0073e9SAndroid Build Coastguard Worker            # theoretical values from Joe & Kuo (2010)
7426*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(
7427*da0073e9SAndroid Build Coastguard Worker                [
7428*da0073e9SAndroid Build Coastguard Worker                    [0., 0.],
7429*da0073e9SAndroid Build Coastguard Worker                    [0.5, 0.5],
7430*da0073e9SAndroid Build Coastguard Worker                    [0.75, 0.25],
7431*da0073e9SAndroid Build Coastguard Worker                    [0.25, 0.75],
7432*da0073e9SAndroid Build Coastguard Worker                    [0.375, 0.375],
7433*da0073e9SAndroid Build Coastguard Worker                    [0.875, 0.875],
7434*da0073e9SAndroid Build Coastguard Worker                    [0.625, 0.125],
7435*da0073e9SAndroid Build Coastguard Worker                    [0.125, 0.625],
7436*da0073e9SAndroid Build Coastguard Worker                ],
7437*da0073e9SAndroid Build Coastguard Worker            )
7438*da0073e9SAndroid Build Coastguard Worker        else:
7439*da0073e9SAndroid Build Coastguard Worker            # theoretical values unknown; validated by checking convergence properties instead
7440*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(
7441*da0073e9SAndroid Build Coastguard Worker                [
7442*da0073e9SAndroid Build Coastguard Worker                    [0.50860737, 0.29320504],
7443*da0073e9SAndroid Build Coastguard Worker                    [0.07116939, 0.89594537],
7444*da0073e9SAndroid Build Coastguard Worker                    [0.49354145, 0.11524881],
7445*da0073e9SAndroid Build Coastguard Worker                    [0.93097717, 0.70244044],
7446*da0073e9SAndroid Build Coastguard Worker                    [0.87266153, 0.23887917],
7447*da0073e9SAndroid Build Coastguard Worker                    [0.31021884, 0.57600391],
7448*da0073e9SAndroid Build Coastguard Worker                    [0.13687253, 0.42054182],
7449*da0073e9SAndroid Build Coastguard Worker                    [0.69931293, 0.77336788],
7450*da0073e9SAndroid Build Coastguard Worker                ],
7451*da0073e9SAndroid Build Coastguard Worker            )
7452*da0073e9SAndroid Build Coastguard Worker
7453*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_bounds(self, scramble: bool = False):
7454*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(100, scramble=scramble, seed=123456)
7455*da0073e9SAndroid Build Coastguard Worker        sample = engine.draw(512)
7456*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(sample >= 0))
7457*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(sample <= 1))
7458*da0073e9SAndroid Build Coastguard Worker
7459*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_bounds_scrambled(self):
7460*da0073e9SAndroid Build Coastguard Worker        self.test_sobolengine_bounds(scramble=True)
7461*da0073e9SAndroid Build Coastguard Worker
7462*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_draw(self, scramble: bool = False):
7463*da0073e9SAndroid Build Coastguard Worker        ref_sample = self._sobol_reference_samples(scramble=scramble)
7464*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7465*da0073e9SAndroid Build Coastguard Worker        sample = engine.draw(n=len(ref_sample))
7466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sample, ref_sample)
7467*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(engine.num_generated, len(ref_sample))
7468*da0073e9SAndroid Build Coastguard Worker
7469*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_draw_scrambled(self):
7470*da0073e9SAndroid Build Coastguard Worker        self.test_sobolengine_draw(scramble=True)
7471*da0073e9SAndroid Build Coastguard Worker
7472*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_first_point(self):
7473*da0073e9SAndroid Build Coastguard Worker        for dtype in (torch.float, torch.double):
7474*da0073e9SAndroid Build Coastguard Worker            engine = torch.quasirandom.SobolEngine(2, scramble=False)
7475*da0073e9SAndroid Build Coastguard Worker            sample = engine.draw(1, dtype=dtype)
7476*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.all(sample == 0))
7477*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sample.dtype, dtype)
7478*da0073e9SAndroid Build Coastguard Worker        for dtype in (torch.float, torch.double):
7479*da0073e9SAndroid Build Coastguard Worker            engine = torch.quasirandom.SobolEngine(2, scramble=True, seed=123456)
7480*da0073e9SAndroid Build Coastguard Worker            sample = engine.draw(1, dtype=dtype)
7481*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.all(sample != 0))
7482*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sample.dtype, dtype)
7483*da0073e9SAndroid Build Coastguard Worker
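    # Drawing the reference points in two halves must match drawing them all at once.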
7484*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_continuing(self, scramble: bool = False):
7485*da0073e9SAndroid Build Coastguard Worker        ref_sample = self._sobol_reference_samples(scramble=scramble)
7486*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7487*da0073e9SAndroid Build Coastguard Worker        n_half = len(ref_sample) // 2
7488*da0073e9SAndroid Build Coastguard Worker        _ = engine.draw(n=n_half)
7489*da0073e9SAndroid Build Coastguard Worker        sample = engine.draw(n=n_half)
7490*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(sample, ref_sample[n_half:])
7491*da0073e9SAndroid Build Coastguard Worker
7492*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_continuing_scrambled(self):
7493*da0073e9SAndroid Build Coastguard Worker        self.test_sobolengine_continuing(scramble=True)
7494*da0073e9SAndroid Build Coastguard Worker
7495*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_reset(self, scramble: bool = False):
7496*da0073e9SAndroid Build Coastguard Worker        ref_sample = self._sobol_reference_samples(scramble=scramble)
7497*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7498*da0073e9SAndroid Build Coastguard Worker        _ = engine.draw(n=len(ref_sample) // 2)
7499*da0073e9SAndroid Build Coastguard Worker        engine.reset()
7500*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(engine.num_generated, 0)
7501*da0073e9SAndroid Build Coastguard Worker        sample = engine.draw(n=len(ref_sample))
7502*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(sample, ref_sample)
7503*da0073e9SAndroid Build Coastguard Worker
7504*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_reset_scrambled(self):
7505*da0073e9SAndroid Build Coastguard Worker        self.test_sobolengine_reset(scramble=True)
7506*da0073e9SAndroid Build Coastguard Worker
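    # fast_forward(k) skips k points, so later draws line up with the reference sequence offset by k;
    # interleaving fast_forward(1) with draw() must pick out the even-indexed reference points.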
7507*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_fast_forward(self, scramble: bool = False):
7508*da0073e9SAndroid Build Coastguard Worker        ref_sample = self._sobol_reference_samples(scramble=scramble)
7509*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7510*da0073e9SAndroid Build Coastguard Worker        engine.fast_forward(4)
7511*da0073e9SAndroid Build Coastguard Worker        sample = engine.draw(n=4)
7512*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(sample, ref_sample[4:])
7513*da0073e9SAndroid Build Coastguard Worker        # alternate fast forwarding with sampling
7514*da0073e9SAndroid Build Coastguard Worker        engine.reset()
7515*da0073e9SAndroid Build Coastguard Worker        even_draws = []
7516*da0073e9SAndroid Build Coastguard Worker        for i in range(8):
7517*da0073e9SAndroid Build Coastguard Worker            if i % 2 == 0:
7518*da0073e9SAndroid Build Coastguard Worker                even_draws.append(engine.draw())
7519*da0073e9SAndroid Build Coastguard Worker            else:
7520*da0073e9SAndroid Build Coastguard Worker                engine.fast_forward(1)
7521*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
7522*da0073e9SAndroid Build Coastguard Worker            ref_sample[[i for i in range(8) if i % 2 == 0]],
7523*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(np.concatenate(even_draws)),
7524*da0073e9SAndroid Build Coastguard Worker        )
7525*da0073e9SAndroid Build Coastguard Worker
7526*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_fast_forward_scrambled(self):
7527*da0073e9SAndroid Build Coastguard Worker        self.test_sobolengine_fast_forward(scramble=True)
7528*da0073e9SAndroid Build Coastguard Worker
7529*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_default_dtype(self):
7530*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456)
7531*da0073e9SAndroid Build Coastguard Worker        # Check that default dtype is correctly handled
7532*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(engine.draw(n=5).dtype, torch.float32)
7533*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float64):
7534*da0073e9SAndroid Build Coastguard Worker            engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456)
7535*da0073e9SAndroid Build Coastguard Worker            # Check that default dtype is correctly handled (when set to float64)
7536*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(engine.draw(n=5).dtype, torch.float64)
7537*da0073e9SAndroid Build Coastguard Worker            # Check that explicitly passed dtype is adhered to
7538*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(engine.draw(n=5, dtype=torch.float32).dtype, torch.float32)
7539*da0073e9SAndroid Build Coastguard Worker            # Reinitialize the engine and check that first draw dtype is correctly handled
7540*da0073e9SAndroid Build Coastguard Worker            engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456)
7541*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(engine.draw(n=5, dtype=torch.float32).dtype, torch.float32)
7542*da0073e9SAndroid Build Coastguard Worker
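    # With 1024 draws in 50 dimensions, the per-dimension mean and the 25th/75th percentiles
    # should fall near 0.5, 0.25 and 0.75 respectively (very loose tolerances).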
7543*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("np.float64 restored as float32 after graph break.")
7544*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_distribution(self, scramble=False):
7545*da0073e9SAndroid Build Coastguard Worker        d = 50
7546*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(d, scramble=scramble, seed=123456)
7547*da0073e9SAndroid Build Coastguard Worker        sample = engine.draw(1024)
7548*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
7549*da0073e9SAndroid Build Coastguard Worker            torch.mean(sample, dim=0), torch.full((d,), 0.5), atol=2, rtol=2
7550*da0073e9SAndroid Build Coastguard Worker        )
7551*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
7552*da0073e9SAndroid Build Coastguard Worker            np.percentile(sample, 25, axis=0), np.repeat(0.25, d), atol=2, rtol=2
7553*da0073e9SAndroid Build Coastguard Worker        )
7554*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
7555*da0073e9SAndroid Build Coastguard Worker            np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2
7556*da0073e9SAndroid Build Coastguard Worker        )
7557*da0073e9SAndroid Build Coastguard Worker
7558*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("np.float64 restored as float32 after graph break.")
7559*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_distribution_scrambled(self):
7560*da0073e9SAndroid Build Coastguard Worker        self.test_sobolengine_distribution(scramble=True)
7561*da0073e9SAndroid Build Coastguard Worker
7562*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_draw_base2(self, scramble=False):
7563*da0073e9SAndroid Build Coastguard Worker        ref_sample = self._sobol_reference_samples(scramble=scramble)
7564*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
7565*da0073e9SAndroid Build Coastguard Worker        sample = engine.draw_base2(2)
7566*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_sample[:4], sample)
7567*da0073e9SAndroid Build Coastguard Worker        # a second draw_base2 call keeps the cumulative number of points at a power of two (N = 2**n)
7568*da0073e9SAndroid Build Coastguard Worker        sample = engine.draw_base2(2)
7569*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_sample[4:8], sample)
7570*da0073e9SAndroid Build Coastguard Worker
7571*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_draw_base2_scrambled(self):
7572*da0073e9SAndroid Build Coastguard Worker        self.test_sobolengine_draw_base2(scramble=True)
7573*da0073e9SAndroid Build Coastguard Worker
7574*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_raise(self):
7575*da0073e9SAndroid Build Coastguard Worker        maxdim = torch.quasirandom.SobolEngine.MAXDIM
7576*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
7577*da0073e9SAndroid Build Coastguard Worker            torch.quasirandom.SobolEngine(maxdim + 1)
7578*da0073e9SAndroid Build Coastguard Worker
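    # For an unscrambled engine in high dimension, the first draw is all zeros and the second is all 0.5.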
7579*da0073e9SAndroid Build Coastguard Worker    def test_sobolengine_high_dim(self):
7580*da0073e9SAndroid Build Coastguard Worker        engine = torch.quasirandom.SobolEngine(1111, scramble=False, seed=123456)
7581*da0073e9SAndroid Build Coastguard Worker        samples1 = engine.draw()
7582*da0073e9SAndroid Build Coastguard Worker        vals1, counts1 = torch.unique(samples1, return_counts=True)
7583*da0073e9SAndroid Build Coastguard Worker        samples2 = engine.draw()
7584*da0073e9SAndroid Build Coastguard Worker        vals2, counts2 = torch.unique(samples2, return_counts=True)
7585*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vals1.item(), 0.0)
7586*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counts1.item(), 1111)
7587*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vals2.item(), 0.5)
7588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counts2.item(), 1111)
7589*da0073e9SAndroid Build Coastguard Worker
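    # The argument parser accepts integer tensors where a Python int is expected but rejects
    # floating-point tensors.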
7590*da0073e9SAndroid Build Coastguard Worker    def test_parsing_int64(self):
7591*da0073e9SAndroid Build Coastguard Worker        # accepts integer arguments
7592*da0073e9SAndroid Build Coastguard Worker        x = torch.cumsum(torch.ones(5, 5), 0)
7593*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.cumsum(torch.ones(5, 5), torch.tensor(0)))
7594*da0073e9SAndroid Build Coastguard Worker        # doesn't accept floating point variables
7595*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0.)))
7596*da0073e9SAndroid Build Coastguard Worker
7597*da0073e9SAndroid Build Coastguard Worker    def test_parsing_double(self):
7598*da0073e9SAndroid Build Coastguard Worker        # accepts floating point and integer arguments
7599*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3)
7600*da0073e9SAndroid Build Coastguard Worker        torch.isclose(x, x, 1, 1)
7601*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isclose(x, x, 1, 1).all())
7602*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isclose(x, x, 1.5, 1.).all())
7603*da0073e9SAndroid Build Coastguard Worker        # accepts floating point and integer tensors
7604*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isclose(x, x, torch.tensor(1), torch.tensor(1)).all())
7605*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1.)).all())
7606*da0073e9SAndroid Build Coastguard Worker        # doesn't accept variables with requires_grad
7607*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError,
7608*da0073e9SAndroid Build Coastguard Worker                          lambda: torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1., requires_grad=True)).all())
7609*da0073e9SAndroid Build Coastguard Worker
7610*da0073e9SAndroid Build Coastguard Worker    def test_parsing_intlist(self):
7611*da0073e9SAndroid Build Coastguard Worker        #  parse with integer variables
7612*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([3, 4]), torch.ones((torch.tensor(3), torch.tensor(4))).shape)
7613*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([3, 4]), torch.ones(torch.tensor(3), torch.tensor(4)).shape)
7614*da0073e9SAndroid Build Coastguard Worker        # parse with numpy integers
7615*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape)
7616*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape)
7617*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape)
7618*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape)
7619*da0073e9SAndroid Build Coastguard Worker
7620*da0073e9SAndroid Build Coastguard Worker        # fail parse with float variables
7621*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3.), torch.tensor(4))))
7622*da0073e9SAndroid Build Coastguard Worker        # fail parse with Python and numpy floats
7623*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.ones((3., torch.tensor(4))))
7624*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.ones((np.array(3.), torch.tensor(4))))
7625*da0073e9SAndroid Build Coastguard Worker
7626*da0073e9SAndroid Build Coastguard Worker        # fail parse with > 1 element variables
7627*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3)))
7629*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3)))
7631*da0073e9SAndroid Build Coastguard Worker
7632*da0073e9SAndroid Build Coastguard Worker        # fail parse with additional positional args after intlist arg
7633*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(TypeError,
7634*da0073e9SAndroid Build Coastguard Worker                               "received an invalid combination of arguments",
7635*da0073e9SAndroid Build Coastguard Worker                               lambda: torch.LongTensor((6, 0), 1, 1, 0))
7636*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(TypeError,
7637*da0073e9SAndroid Build Coastguard Worker                               "missing 1 required positional arguments",
7638*da0073e9SAndroid Build Coastguard Worker                               lambda: torch.tensor().new_zeros((5, 5), 0))
7639*da0073e9SAndroid Build Coastguard Worker
7640*da0073e9SAndroid Build Coastguard Worker    def test_from_buffer(self):
7641*da0073e9SAndroid Build Coastguard Worker        a = bytearray([1, 2, 3, 4])
7642*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
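        # big-endian int16 view of [1, 2, 3, 4]: 0x0102 == 258 and 0x0304 == 772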
7643*da0073e9SAndroid Build Coastguard Worker        shorts = torch.ShortStorage.from_buffer(a, 'big')
7644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shorts.size(), 2)
7645*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shorts.tolist(), [258, 772])
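        # little-endian int32 view of [1, 2, 3, 4]: 0x04030201 == 67305985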
7646*da0073e9SAndroid Build Coastguard Worker        ints = torch.IntStorage.from_buffer(a, 'little')
7647*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ints.size(), 1)
7648*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ints[0], 67305985)
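        # 0x40100000 is the big-endian IEEE-754 float32 encoding of 2.25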
7649*da0073e9SAndroid Build Coastguard Worker        f = bytearray([0x40, 0x10, 0x00, 0x00])
7650*da0073e9SAndroid Build Coastguard Worker        floats = torch.FloatStorage.from_buffer(f, 'big')
7651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(floats.size(), 1)
7652*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(floats[0], 2.25)
7653*da0073e9SAndroid Build Coastguard Worker
7654*da0073e9SAndroid Build Coastguard Worker        f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
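        # from_buffer into a bool storage maps every nonzero byte to True, so only
        # the leading 0x00 byte becomes False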
7655*da0073e9SAndroid Build Coastguard Worker        bools = torch.BoolStorage.from_buffer(f, 'big')
7656*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bools.size(), 8)
7657*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True])
7658*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bools.type(), 'torch.BoolStorage')
7659*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(bools, torch.BoolStorage))
7660*da0073e9SAndroid Build Coastguard Worker
7661*da0073e9SAndroid Build Coastguard Worker        f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9')
7662*da0073e9SAndroid Build Coastguard Worker        bools = torch.BoolStorage.from_buffer(f, 'big')
7663*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bools.size(), 19)
7664*da0073e9SAndroid Build Coastguard Worker
7665*da0073e9SAndroid Build Coastguard Worker        f = bytearray(b'\0x4A')
7666*da0073e9SAndroid Build Coastguard Worker        bools = torch.BoolStorage.from_buffer(f, 'big')
7667*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bools.size(), 4)
7668*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bools.tolist(), [False, True, True, True])
7669*da0073e9SAndroid Build Coastguard Worker        byte_storage = torch.ByteStorage.from_buffer(a)
7670*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byte_storage.nbytes(), 4)
7671*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byte_storage.tolist(), [1, 2, 3, 4])
7672*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(byte_storage, torch.ByteStorage))
7673*da0073e9SAndroid Build Coastguard Worker
7674*da0073e9SAndroid Build Coastguard Worker    def test_storage_error(self):
7675*da0073e9SAndroid Build Coastguard Worker        quantized_storages = [
7676*da0073e9SAndroid Build Coastguard Worker            torch.QInt32Storage,
7677*da0073e9SAndroid Build Coastguard Worker            torch.QInt8Storage,
7678*da0073e9SAndroid Build Coastguard Worker            torch.QUInt2x4Storage,
7679*da0073e9SAndroid Build Coastguard Worker            torch.QUInt4x2Storage,
7680*da0073e9SAndroid Build Coastguard Worker            torch.QUInt8Storage,
7681*da0073e9SAndroid Build Coastguard Worker        ]
7682*da0073e9SAndroid Build Coastguard Worker
7683*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Only child classes of _LegacyStorage can be instantiated"):
7684*da0073e9SAndroid Build Coastguard Worker            torch.storage._LegacyStorage()
7685*da0073e9SAndroid Build Coastguard Worker
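        # Every legacy <dtype>Storage class should reject these constructor
        # arguments in the same way; quantized storages additionally cannot be
        # created on (or moved to) CUDA.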
7686*da0073e9SAndroid Build Coastguard Worker        for storage_class in torch._storage_classes:
7687*da0073e9SAndroid Build Coastguard Worker            if storage_class in [torch.UntypedStorage, torch.TypedStorage]:
7688*da0073e9SAndroid Build Coastguard Worker                continue
7689*da0073e9SAndroid Build Coastguard Worker
7690*da0073e9SAndroid Build Coastguard Worker            device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
7691*da0073e9SAndroid Build Coastguard Worker            dtype = storage_class.dtype
7692*da0073e9SAndroid Build Coastguard Worker
7693*da0073e9SAndroid Build Coastguard Worker            if device == 'cuda' and not torch.cuda.is_available():
7694*da0073e9SAndroid Build Coastguard Worker                continue
7695*da0073e9SAndroid Build Coastguard Worker
7696*da0073e9SAndroid Build Coastguard Worker            # Legacy <type>Storage constructor errors
7697*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"'device' cannot be specified"):
7698*da0073e9SAndroid Build Coastguard Worker                storage_class(device='cpu')
7699*da0073e9SAndroid Build Coastguard Worker
7700*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"'dtype' cannot be specified"):
7701*da0073e9SAndroid Build Coastguard Worker                storage_class(dtype=torch.float)
7702*da0073e9SAndroid Build Coastguard Worker
7703*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, r"got an unexpected keyword"):
7704*da0073e9SAndroid Build Coastguard Worker                storage_class(sdlkjf=torch.float)
7705*da0073e9SAndroid Build Coastguard Worker
7706*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
7707*da0073e9SAndroid Build Coastguard Worker                storage_class(0, 0)
7708*da0073e9SAndroid Build Coastguard Worker
7709*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, r"invalid data type"):
7710*da0073e9SAndroid Build Coastguard Worker                storage_class('string')
7711*da0073e9SAndroid Build Coastguard Worker
7712*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
7713*da0073e9SAndroid Build Coastguard Worker                storage_class(torch.tensor([]))
7714*da0073e9SAndroid Build Coastguard Worker
7715*da0073e9SAndroid Build Coastguard Worker            s = storage_class()
7716*da0073e9SAndroid Build Coastguard Worker
7717*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
7718*da0073e9SAndroid Build Coastguard Worker                storage_class(0, wrap_storage=s.untyped())
7719*da0073e9SAndroid Build Coastguard Worker
7720*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, r"must be UntypedStorage"):
7721*da0073e9SAndroid Build Coastguard Worker                storage_class(wrap_storage=s)
7722*da0073e9SAndroid Build Coastguard Worker
7723*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.is_available():
7724*da0073e9SAndroid Build Coastguard Worker                if storage_class in quantized_storages:
7725*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
7726*da0073e9SAndroid Build Coastguard Worker                        s.cuda()
7727*da0073e9SAndroid Build Coastguard Worker
7728*da0073e9SAndroid Build Coastguard Worker                else:
7729*da0073e9SAndroid Build Coastguard Worker
7730*da0073e9SAndroid Build Coastguard Worker                    if s.is_cuda:
7731*da0073e9SAndroid Build Coastguard Worker                        s_other_device = s.cpu()
7732*da0073e9SAndroid Build Coastguard Worker                    else:
7733*da0073e9SAndroid Build Coastguard Worker                        s_other_device = s.cuda()
7734*da0073e9SAndroid Build Coastguard Worker
7735*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"):
7736*da0073e9SAndroid Build Coastguard Worker                        storage_class(wrap_storage=s_other_device.untyped())
7737*da0073e9SAndroid Build Coastguard Worker
7738*da0073e9SAndroid Build Coastguard Worker            # TypedStorage constructor errors
7739*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
7740*da0073e9SAndroid Build Coastguard Worker                torch.TypedStorage(0, wrap_storage=s.untyped(), dtype=dtype)
7741*da0073e9SAndroid Build Coastguard Worker
7742*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"):
7743*da0073e9SAndroid Build Coastguard Worker                torch.TypedStorage(wrap_storage=s.untyped())
7744*da0073e9SAndroid Build Coastguard Worker
7745*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"):
7746*da0073e9SAndroid Build Coastguard Worker                torch.TypedStorage(wrap_storage=s.untyped(), dtype=0)
7747*da0073e9SAndroid Build Coastguard Worker
7748*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"):
7749*da0073e9SAndroid Build Coastguard Worker                torch.TypedStorage(wrap_storage=s.untyped(), dtype=dtype, device=device)
7750*da0073e9SAndroid Build Coastguard Worker
7751*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be UntypedStorage"):
7752*da0073e9SAndroid Build Coastguard Worker                torch.TypedStorage(wrap_storage=s, dtype=dtype)
7753*da0073e9SAndroid Build Coastguard Worker
7754*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"):
7755*da0073e9SAndroid Build Coastguard Worker                torch.TypedStorage(dtype=dtype, device='xla')
7756*da0073e9SAndroid Build Coastguard Worker
7757*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.is_available():
7758*da0073e9SAndroid Build Coastguard Worker                if storage_class in quantized_storages:
7759*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
7760*da0073e9SAndroid Build Coastguard Worker                        torch.TypedStorage(dtype=dtype, device='cuda')
7761*da0073e9SAndroid Build Coastguard Worker
7762*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
7763*da0073e9SAndroid Build Coastguard Worker                torch.TypedStorage(torch.tensor([]), dtype=dtype, device=device)
7764*da0073e9SAndroid Build Coastguard Worker
7765*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
7766*da0073e9SAndroid Build Coastguard Worker                torch.TypedStorage(0, 0, dtype=dtype, device=device)
7767*da0073e9SAndroid Build Coastguard Worker
7768*da0073e9SAndroid Build Coastguard Worker            if isinstance(s, torch.TypedStorage):
7769*da0073e9SAndroid Build Coastguard Worker                s_other = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
7770*da0073e9SAndroid Build Coastguard Worker
7771*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, r'cannot set item'):
7772*da0073e9SAndroid Build Coastguard Worker                    s.fill_(s_other)
7773*da0073e9SAndroid Build Coastguard Worker
7774*da0073e9SAndroid Build Coastguard Worker    def test_storage_error_no_attribute(self):
7775*da0073e9SAndroid Build Coastguard Worker        storage_classes = [
7776*da0073e9SAndroid Build Coastguard Worker            torch.cuda.ByteStorage,
7777*da0073e9SAndroid Build Coastguard Worker            torch.cuda.FloatStorage,
7778*da0073e9SAndroid Build Coastguard Worker        ]
7779*da0073e9SAndroid Build Coastguard Worker        for storage_class in storage_classes:
7780*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
7781*da0073e9SAndroid Build Coastguard Worker                storage_class.from_buffer()
7782*da0073e9SAndroid Build Coastguard Worker
7783*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
7784*da0073e9SAndroid Build Coastguard Worker                storage_class._new_with_weak_ptr()
7785*da0073e9SAndroid Build Coastguard Worker
7786*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
7787*da0073e9SAndroid Build Coastguard Worker                storage_class._new_shared_filename(0, 0, 0)
7788*da0073e9SAndroid Build Coastguard Worker
7789*da0073e9SAndroid Build Coastguard Worker    def test_storage_casts(self):
7790*da0073e9SAndroid Build Coastguard Worker        storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])
7791*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage.size(), 6)
7792*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
7793*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage.type(), 'torch.IntStorage')
7794*da0073e9SAndroid Build Coastguard Worker        self.assertIs(storage.dtype, torch.int32)
7795*da0073e9SAndroid Build Coastguard Worker
7796*da0073e9SAndroid Build Coastguard Worker        floatStorage = storage.float()
7797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(floatStorage.size(), 6)
7798*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(floatStorage.type(), 'torch.FloatStorage')
7800*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7801*da0073e9SAndroid Build Coastguard Worker        self.assertIs(floatStorage.dtype, torch.float32)
7802*da0073e9SAndroid Build Coastguard Worker
7803*da0073e9SAndroid Build Coastguard Worker        halfStorage = storage.half()
7804*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfStorage.size(), 6)
7805*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7806*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfStorage.type(), 'torch.HalfStorage')
7807*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7808*da0073e9SAndroid Build Coastguard Worker        self.assertIs(halfStorage.dtype, torch.float16)
7809*da0073e9SAndroid Build Coastguard Worker
7810*da0073e9SAndroid Build Coastguard Worker        bfloat16Storage = storage.bfloat16()
7811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bfloat16Storage.size(), 6)
7812*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4])
7813*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage')
7814*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7815*da0073e9SAndroid Build Coastguard Worker        self.assertIs(bfloat16Storage.dtype, torch.bfloat16)
7816*da0073e9SAndroid Build Coastguard Worker
7817*da0073e9SAndroid Build Coastguard Worker        longStorage = storage.long()
7818*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(longStorage.size(), 6)
7819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(longStorage.type(), 'torch.LongStorage')
7821*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7822*da0073e9SAndroid Build Coastguard Worker        self.assertIs(longStorage.dtype, torch.int64)
7823*da0073e9SAndroid Build Coastguard Worker
7824*da0073e9SAndroid Build Coastguard Worker        shortStorage = storage.short()
7825*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shortStorage.size(), 6)
7826*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shortStorage.type(), 'torch.ShortStorage')
7828*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7829*da0073e9SAndroid Build Coastguard Worker        self.assertIs(shortStorage.dtype, torch.int16)
7830*da0073e9SAndroid Build Coastguard Worker
7831*da0073e9SAndroid Build Coastguard Worker        doubleStorage = storage.double()
7832*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(doubleStorage.size(), 6)
7833*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
7834*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage')
7835*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7836*da0073e9SAndroid Build Coastguard Worker        self.assertIs(doubleStorage.dtype, torch.float64)
7837*da0073e9SAndroid Build Coastguard Worker
7838*da0073e9SAndroid Build Coastguard Worker        charStorage = storage.char()
7839*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(charStorage.size(), 6)
7840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(charStorage.tolist(), [-1, 0, 1, 2, 3, 4])
7841*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(charStorage.type(), 'torch.CharStorage')
7842*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
7843*da0073e9SAndroid Build Coastguard Worker        self.assertIs(charStorage.dtype, torch.int8)
7844*da0073e9SAndroid Build Coastguard Worker
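        # casting to uint8 wraps negative values, so -1 becomes 255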
7845*da0073e9SAndroid Build Coastguard Worker        byteStorage = storage.byte()
7846*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byteStorage.size(), 6)
7847*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
7848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byteStorage.type(), 'torch.ByteStorage')
7849*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
7850*da0073e9SAndroid Build Coastguard Worker        self.assertIs(byteStorage.dtype, torch.uint8)
7851*da0073e9SAndroid Build Coastguard Worker
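        # casting to bool maps nonzero values to True and zero to False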
7852*da0073e9SAndroid Build Coastguard Worker        boolStorage = storage.bool()
7853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(boolStorage.size(), 6)
7854*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True])
7855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(boolStorage.type(), 'torch.BoolStorage')
7856*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
7857*da0073e9SAndroid Build Coastguard Worker        self.assertIs(boolStorage.dtype, torch.bool)
7858*da0073e9SAndroid Build Coastguard Worker
7859*da0073e9SAndroid Build Coastguard Worker        complexfloat_storage = torch.ComplexFloatStorage([-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
7860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexfloat_storage.size(), 6)
7861*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexfloat_storage.tolist(), [-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
7862*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexfloat_storage.type(), 'torch.ComplexFloatStorage')
7863*da0073e9SAndroid Build Coastguard Worker        self.assertIs(complexfloat_storage.dtype, torch.complex64)
7864*da0073e9SAndroid Build Coastguard Worker
7865*da0073e9SAndroid Build Coastguard Worker        complexdouble_storage = complexfloat_storage.complex_double()
7866*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexdouble_storage.size(), 6)
7867*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexdouble_storage.tolist(), [-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
7868*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
7869*da0073e9SAndroid Build Coastguard Worker        self.assertIs(complexdouble_storage.dtype, torch.complex128)
7870*da0073e9SAndroid Build Coastguard Worker
7871*da0073e9SAndroid Build Coastguard Worker    def test_storage_byteswap(self):
7872*da0073e9SAndroid Build Coastguard Worker        input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
7873*da0073e9SAndroid Build Coastguard Worker        swapped_8bytes = [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8]
7874*da0073e9SAndroid Build Coastguard Worker        swapped_4bytes = [3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12]
7875*da0073e9SAndroid Build Coastguard Worker        swapped_2bytes = [1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14]
7876*da0073e9SAndroid Build Coastguard Worker        swapped_1byte = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
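        # byteswap reverses the byte order within each element: 8-byte dtypes swap
        # groups of eight bytes, 4-byte dtypes groups of four, 2-byte dtypes groups
        # of two, and 1-byte dtypes are left unchanged.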
7877*da0073e9SAndroid Build Coastguard Worker
7878*da0073e9SAndroid Build Coastguard Worker        storage = torch.storage.TypedStorage(input, dtype=torch.uint8)._untyped_storage
7879*da0073e9SAndroid Build Coastguard Worker
7880*da0073e9SAndroid Build Coastguard Worker        storage_f64 = storage.__copy__()
7881*da0073e9SAndroid Build Coastguard Worker        storage_f64.byteswap(torch.float64)
7882*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_f64.tolist(), swapped_8bytes)
7883*da0073e9SAndroid Build Coastguard Worker
7884*da0073e9SAndroid Build Coastguard Worker        storage_f32 = storage.__copy__()
7885*da0073e9SAndroid Build Coastguard Worker        storage_f32.byteswap(torch.float32)
7886*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_f32.tolist(), swapped_4bytes)
7887*da0073e9SAndroid Build Coastguard Worker
7888*da0073e9SAndroid Build Coastguard Worker        storage_f16 = storage.__copy__()
7889*da0073e9SAndroid Build Coastguard Worker        storage_f16.byteswap(torch.float16)
7890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_f16.tolist(), swapped_2bytes)
7891*da0073e9SAndroid Build Coastguard Worker
7892*da0073e9SAndroid Build Coastguard Worker        storage_bf16 = storage.__copy__()
7893*da0073e9SAndroid Build Coastguard Worker        storage_bf16.byteswap(torch.bfloat16)
7894*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_bf16.tolist(), swapped_2bytes)
7895*da0073e9SAndroid Build Coastguard Worker
7896*da0073e9SAndroid Build Coastguard Worker        storage_i64 = storage.__copy__()
7897*da0073e9SAndroid Build Coastguard Worker        storage_i64.byteswap(torch.int64)
7898*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_i64.tolist(), swapped_8bytes)
7899*da0073e9SAndroid Build Coastguard Worker
7900*da0073e9SAndroid Build Coastguard Worker        storage_i32 = storage.__copy__()
7901*da0073e9SAndroid Build Coastguard Worker        storage_i32.byteswap(torch.int32)
7902*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_i32.tolist(), swapped_4bytes)
7903*da0073e9SAndroid Build Coastguard Worker
7904*da0073e9SAndroid Build Coastguard Worker        storage_i16 = storage.__copy__()
7905*da0073e9SAndroid Build Coastguard Worker        storage_i16.byteswap(torch.int16)
7906*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_i16.tolist(), swapped_2bytes)
7907*da0073e9SAndroid Build Coastguard Worker
7908*da0073e9SAndroid Build Coastguard Worker        storage_i8 = storage.__copy__()
7909*da0073e9SAndroid Build Coastguard Worker        storage_i8.byteswap(torch.int8)
7910*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_i8.tolist(), swapped_1byte)
7911*da0073e9SAndroid Build Coastguard Worker
7912*da0073e9SAndroid Build Coastguard Worker        storage_ui8 = storage.__copy__()
7913*da0073e9SAndroid Build Coastguard Worker        storage_ui8.byteswap(torch.uint8)
7914*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_ui8.tolist(), swapped_1byte)
7915*da0073e9SAndroid Build Coastguard Worker
7916*da0073e9SAndroid Build Coastguard Worker        storage_bool = storage.__copy__()
7917*da0073e9SAndroid Build Coastguard Worker        storage_bool.byteswap(torch.bool)
7918*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_bool.tolist(), swapped_1byte)
7919*da0073e9SAndroid Build Coastguard Worker
7920*da0073e9SAndroid Build Coastguard Worker        storage_c128 = storage.__copy__()
7921*da0073e9SAndroid Build Coastguard Worker        storage_c128.byteswap(torch.complex128)
7922*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_c128.tolist(), swapped_8bytes)
7923*da0073e9SAndroid Build Coastguard Worker
7924*da0073e9SAndroid Build Coastguard Worker        storage_c64 = storage.__copy__()
7925*da0073e9SAndroid Build Coastguard Worker        storage_c64.byteswap(torch.complex64)
7926*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(storage_c64.tolist(), swapped_4bytes)
7927*da0073e9SAndroid Build Coastguard Worker
7928*da0073e9SAndroid Build Coastguard Worker    # Test that internal versions of functions related to TypedStorage do not
7929*da0073e9SAndroid Build Coastguard Worker    # produce a deprecation warning
7930*da0073e9SAndroid Build Coastguard Worker    def test_typed_storage_internal_no_warning(self):
7931*da0073e9SAndroid Build Coastguard Worker        s0 = torch.FloatStorage(10)
7932*da0073e9SAndroid Build Coastguard Worker        s0_untyped = s0.untyped()
7933*da0073e9SAndroid Build Coastguard Worker        t0 = torch.randn(10)
7934*da0073e9SAndroid Build Coastguard Worker
7935*da0073e9SAndroid Build Coastguard Worker        funcs = [
7936*da0073e9SAndroid Build Coastguard Worker            lambda: torch.FloatStorage(_internal=True),
7937*da0073e9SAndroid Build Coastguard Worker            lambda: torch.TypedStorage(
7938*da0073e9SAndroid Build Coastguard Worker                dtype=torch.float,
7939*da0073e9SAndroid Build Coastguard Worker                device='cpu',
7940*da0073e9SAndroid Build Coastguard Worker                _internal=True),
7941*da0073e9SAndroid Build Coastguard Worker            lambda: torch.TypedStorage(
7942*da0073e9SAndroid Build Coastguard Worker                wrap_storage=s0_untyped,
7943*da0073e9SAndroid Build Coastguard Worker                dtype=s0.dtype,
7944*da0073e9SAndroid Build Coastguard Worker                _internal=True),
7945*da0073e9SAndroid Build Coastguard Worker            lambda: torch.FloatStorage._dtype,
7946*da0073e9SAndroid Build Coastguard Worker            lambda: s0._resize_(20),
7947*da0073e9SAndroid Build Coastguard Worker            lambda: s0._size(),
7948*da0073e9SAndroid Build Coastguard Worker            lambda: s0._untyped_storage,
7949*da0073e9SAndroid Build Coastguard Worker            lambda: s0._is_shared(),
7950*da0073e9SAndroid Build Coastguard Worker            lambda: s0._share_memory_(),
7951*da0073e9SAndroid Build Coastguard Worker            lambda: s0._pickle_storage_type(),
7952*da0073e9SAndroid Build Coastguard Worker            lambda: s0._setitem(slice(0, s0._size()), 1),
7953*da0073e9SAndroid Build Coastguard Worker            lambda: s0._element_size(),
7954*da0073e9SAndroid Build Coastguard Worker            lambda: s0._deepcopy({}),
7955*da0073e9SAndroid Build Coastguard Worker            lambda: s0._data_ptr(),
7956*da0073e9SAndroid Build Coastguard Worker            lambda: s0._nbytes(),
7957*da0073e9SAndroid Build Coastguard Worker            lambda: t0._typed_storage(),
7958*da0073e9SAndroid Build Coastguard Worker        ]
7959*da0073e9SAndroid Build Coastguard Worker
7960*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
7961*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.FloatStorage(10)
7962*da0073e9SAndroid Build Coastguard Worker            s1_untyped = s1.untyped()
7963*da0073e9SAndroid Build Coastguard Worker            t1 = torch.randn(10, device='cuda')
7964*da0073e9SAndroid Build Coastguard Worker
7965*da0073e9SAndroid Build Coastguard Worker            funcs += [
7966*da0073e9SAndroid Build Coastguard Worker                lambda: torch.cuda.FloatStorage(_internal=True),
7967*da0073e9SAndroid Build Coastguard Worker                lambda: torch.TypedStorage(
7968*da0073e9SAndroid Build Coastguard Worker                    dtype=torch.float,
7969*da0073e9SAndroid Build Coastguard Worker                    device='cuda',
7970*da0073e9SAndroid Build Coastguard Worker                    _internal=True),
7971*da0073e9SAndroid Build Coastguard Worker                lambda: torch.TypedStorage(
7972*da0073e9SAndroid Build Coastguard Worker                    wrap_storage=s1_untyped,
7973*da0073e9SAndroid Build Coastguard Worker                    dtype=s1.dtype,
7974*da0073e9SAndroid Build Coastguard Worker                    _internal=True),
7975*da0073e9SAndroid Build Coastguard Worker                lambda: torch.cuda.FloatStorage._dtype,
7976*da0073e9SAndroid Build Coastguard Worker                lambda: s1._resize_(20),
7977*da0073e9SAndroid Build Coastguard Worker                lambda: s1._size(),
7978*da0073e9SAndroid Build Coastguard Worker                lambda: s1._untyped_storage,
7979*da0073e9SAndroid Build Coastguard Worker                lambda: s1._is_shared(),
7980*da0073e9SAndroid Build Coastguard Worker                lambda: s1._share_memory_(),
7981*da0073e9SAndroid Build Coastguard Worker                lambda: s1._pickle_storage_type(),
7982*da0073e9SAndroid Build Coastguard Worker                lambda: s1._setitem(slice(0, s1._size()), 1),
7983*da0073e9SAndroid Build Coastguard Worker                lambda: s1._element_size(),
7984*da0073e9SAndroid Build Coastguard Worker                lambda: s1._deepcopy({}),
7985*da0073e9SAndroid Build Coastguard Worker                lambda: s1._data_ptr(),
7986*da0073e9SAndroid Build Coastguard Worker                lambda: s1._nbytes(),
7987*da0073e9SAndroid Build Coastguard Worker                lambda: t1._typed_storage(),
7988*da0073e9SAndroid Build Coastguard Worker            ]
7989*da0073e9SAndroid Build Coastguard Worker
7990*da0073e9SAndroid Build Coastguard Worker        # Check that none of the TypedStorage internal function calls produces
7991*da0073e9SAndroid Build Coastguard Worker        # a deprecation warning
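        # filterwarnings('error', ...) escalates any warning whose message matches
        # into an exception, so f() would raise if the deprecation warning fired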
7992*da0073e9SAndroid Build Coastguard Worker        for f in funcs:
7993*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings():
7994*da0073e9SAndroid Build Coastguard Worker                warnings.filterwarnings('error', "TypedStorage is deprecated")
7995*da0073e9SAndroid Build Coastguard Worker                f()
7996*da0073e9SAndroid Build Coastguard Worker
7997*da0073e9SAndroid Build Coastguard Worker    # Test that public functions related to TypedStorage produce a deprecation
7998*da0073e9SAndroid Build Coastguard Worker    # warning
7999*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("FIXME")
8000*da0073e9SAndroid Build Coastguard Worker    def test_typed_storage_deprecation_warning(self):
8001*da0073e9SAndroid Build Coastguard Worker        s0 = torch.FloatStorage(10)
8002*da0073e9SAndroid Build Coastguard Worker        funcs = [
8003*da0073e9SAndroid Build Coastguard Worker            lambda: torch.FloatStorage(),
8004*da0073e9SAndroid Build Coastguard Worker            lambda: torch.FloatStorage.dtype,
8005*da0073e9SAndroid Build Coastguard Worker            lambda: s0.fill_(0),
8006*da0073e9SAndroid Build Coastguard Worker            lambda: s0.is_cuda,
8007*da0073e9SAndroid Build Coastguard Worker            lambda: s0.untyped(),
8008*da0073e9SAndroid Build Coastguard Worker            lambda: len(s0),
8009*da0073e9SAndroid Build Coastguard Worker            lambda: s0[0],
8010*da0073e9SAndroid Build Coastguard Worker        ]
8011*da0073e9SAndroid Build Coastguard Worker
8012*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
8013*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.FloatStorage(10)
8014*da0073e9SAndroid Build Coastguard Worker            funcs += [
8015*da0073e9SAndroid Build Coastguard Worker                lambda: torch.cuda.FloatStorage(),
8016*da0073e9SAndroid Build Coastguard Worker                lambda: torch.cuda.FloatStorage.dtype,
8017*da0073e9SAndroid Build Coastguard Worker                lambda: s1.fill_(0),
8018*da0073e9SAndroid Build Coastguard Worker                lambda: s1.is_cuda,
8019*da0073e9SAndroid Build Coastguard Worker                lambda: s1.untyped(),
8020*da0073e9SAndroid Build Coastguard Worker                lambda: len(s1),
8021*da0073e9SAndroid Build Coastguard Worker                lambda: s1[0],
8022*da0073e9SAndroid Build Coastguard Worker            ]
8023*da0073e9SAndroid Build Coastguard Worker
8024*da0073e9SAndroid Build Coastguard Worker        # Check that each of the TypedStorage function calls produces a warning
8025*da0073e9SAndroid Build Coastguard Worker        # when the warning state is reset between calls
8026*da0073e9SAndroid Build Coastguard Worker        for f in funcs:
8027*da0073e9SAndroid Build Coastguard Worker            with AlwaysWarnTypedStorageRemoval(True):
8028*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True) as w:
8029*da0073e9SAndroid Build Coastguard Worker                    warnings.resetwarnings()
8030*da0073e9SAndroid Build Coastguard Worker                    f()
8031*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(len(w), 1, msg=str([str(a) for a in w]))
8032*da0073e9SAndroid Build Coastguard Worker                    warning = w[0].message
8033*da0073e9SAndroid Build Coastguard Worker                    self.assertIsInstance(warning, UserWarning)
8034*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(re.search(
8035*da0073e9SAndroid Build Coastguard Worker                        '^TypedStorage is deprecated',
8036*da0073e9SAndroid Build Coastguard Worker                        str(warning)))
8037*da0073e9SAndroid Build Coastguard Worker
8038*da0073e9SAndroid Build Coastguard Worker        # Test that only the first warning is raised by default
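        # _reset_warn_typed_storage_removal clears the "already warned" flag so the
        # deprecation warning can be emitted once more below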
8039*da0073e9SAndroid Build Coastguard Worker        torch.storage._reset_warn_typed_storage_removal()
8040*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
8041*da0073e9SAndroid Build Coastguard Worker            warnings.resetwarnings()
8042*da0073e9SAndroid Build Coastguard Worker            torch.FloatStorage()
8043*da0073e9SAndroid Build Coastguard Worker            torch.randn(10).storage()
8044*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1, msg=str([str(a) for a in w]))
8045*da0073e9SAndroid Build Coastguard Worker            warning = w[0].message
8046*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(re.search(
8047*da0073e9SAndroid Build Coastguard Worker                '^TypedStorage is deprecated',
8048*da0073e9SAndroid Build Coastguard Worker                str(warning)))
8049*da0073e9SAndroid Build Coastguard Worker            # Check the line of code from the warning's stack
8050*da0073e9SAndroid Build Coastguard Worker            with open(w[0].filename, encoding="utf-8") as f:
8051*da0073e9SAndroid Build Coastguard Worker                code_line = f.readlines()[w[0].lineno - 1]
8052*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(re.search(re.escape('torch.FloatStorage()'), code_line))
8053*da0073e9SAndroid Build Coastguard Worker
8054*da0073e9SAndroid Build Coastguard Worker        # Check that no warning is emitted if one was already emitted earlier in the process
8055*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
8056*da0073e9SAndroid Build Coastguard Worker            warnings.resetwarnings()
8057*da0073e9SAndroid Build Coastguard Worker            torch.FloatStorage()
8058*da0073e9SAndroid Build Coastguard Worker            torch.randn(10).storage()
8059*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0, msg=str([str(a) for a in w]))
8060*da0073e9SAndroid Build Coastguard Worker
8061*da0073e9SAndroid Build Coastguard Worker    def test_from_file(self):
8062*da0073e9SAndroid Build Coastguard Worker        def assert_with_filename(filename):
8063*da0073e9SAndroid Build Coastguard Worker            size = 10000
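            # the second argument (shared=True) memory-maps the file, so storages
            # created from the same file alias the same underlying bytes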
8064*da0073e9SAndroid Build Coastguard Worker            s1 = torch.FloatStorage.from_file(filename, True, size)
8065*da0073e9SAndroid Build Coastguard Worker            t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
8066*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s1.data_ptr(), torch.FloatTensor(s1).data_ptr())
8067*da0073e9SAndroid Build Coastguard Worker
8068*da0073e9SAndroid Build Coastguard Worker            # check mapping
8069*da0073e9SAndroid Build Coastguard Worker            s2 = torch.FloatStorage.from_file(filename, True, size)
8070*da0073e9SAndroid Build Coastguard Worker            t2 = torch.FloatTensor(s2)
8071*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1, t2, atol=0, rtol=0)
8072*da0073e9SAndroid Build Coastguard Worker
8073*da0073e9SAndroid Build Coastguard Worker            # check changes to t1 from t2
8074*da0073e9SAndroid Build Coastguard Worker            rnum = random.uniform(-1, 1)
8075*da0073e9SAndroid Build Coastguard Worker            t1.fill_(rnum)
8076*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1, t2, atol=0, rtol=0)
8077*da0073e9SAndroid Build Coastguard Worker
8078*da0073e9SAndroid Build Coastguard Worker            # check changes to t2 from t1
8079*da0073e9SAndroid Build Coastguard Worker            rnum = random.uniform(-1, 1)
8080*da0073e9SAndroid Build Coastguard Worker            t2.fill_(rnum)
8081*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1, t2, atol=0, rtol=0)
8082*da0073e9SAndroid Build Coastguard Worker
8083*da0073e9SAndroid Build Coastguard Worker            # release the tensors
8084*da0073e9SAndroid Build Coastguard Worker            del s1, t1, s2, t2
8085*da0073e9SAndroid Build Coastguard Worker
8086*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
8087*da0073e9SAndroid Build Coastguard Worker            assert_with_filename(fname)
8088*da0073e9SAndroid Build Coastguard Worker
8089*da0073e9SAndroid Build Coastguard Worker        if IS_FILESYSTEM_UTF8_ENCODING:
8090*da0073e9SAndroid Build Coastguard Worker            with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname, TemporaryFileName(dir=dname) as fname:
8091*da0073e9SAndroid Build Coastguard Worker                assert_with_filename(fname)
8092*da0073e9SAndroid Build Coastguard Worker
8093*da0073e9SAndroid Build Coastguard Worker    def test_torch_from_file(self):
8094*da0073e9SAndroid Build Coastguard Worker        def assert_with_filename(filename):
8095*da0073e9SAndroid Build Coastguard Worker            size = 10000
8096*da0073e9SAndroid Build Coastguard Worker            s1 = torch.from_file(filename, True, size, dtype=torch.float)
8097*da0073e9SAndroid Build Coastguard Worker            t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
8098*da0073e9SAndroid Build Coastguard Worker
8099*da0073e9SAndroid Build Coastguard Worker            # check mapping
8100*da0073e9SAndroid Build Coastguard Worker            s2 = torch.from_file(filename, True, size, dtype=torch.float)
8101*da0073e9SAndroid Build Coastguard Worker            t2 = torch.FloatTensor(s2)
8102*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1, t2, atol=0, rtol=0)
8103*da0073e9SAndroid Build Coastguard Worker
8104*da0073e9SAndroid Build Coastguard Worker            # check changes to t1 from t2
8105*da0073e9SAndroid Build Coastguard Worker            rnum = random.uniform(-1, 1)
8106*da0073e9SAndroid Build Coastguard Worker            t1.fill_(rnum)
8107*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1, t2, atol=0, rtol=0)
8108*da0073e9SAndroid Build Coastguard Worker
8109*da0073e9SAndroid Build Coastguard Worker            # check changes to t2 from t1
8110*da0073e9SAndroid Build Coastguard Worker            rnum = random.uniform(-1, 1)
8111*da0073e9SAndroid Build Coastguard Worker            t2.fill_(rnum)
8112*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1, t2, atol=0, rtol=0)
8113*da0073e9SAndroid Build Coastguard Worker
8114*da0073e9SAndroid Build Coastguard Worker            # release the tensors
8115*da0073e9SAndroid Build Coastguard Worker            del s1, t1, s2, t2
8116*da0073e9SAndroid Build Coastguard Worker
8117*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
8118*da0073e9SAndroid Build Coastguard Worker            assert_with_filename(fname)
8119*da0073e9SAndroid Build Coastguard Worker
8120*da0073e9SAndroid Build Coastguard Worker        if IS_FILESYSTEM_UTF8_ENCODING:
8121*da0073e9SAndroid Build Coastguard Worker            with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname, TemporaryFileName(dir=dname) as fname:
8122*da0073e9SAndroid Build Coastguard Worker                assert_with_filename(fname)
8123*da0073e9SAndroid Build Coastguard Worker
8124*da0073e9SAndroid Build Coastguard Worker    def test_print(self):
8125*da0073e9SAndroid Build Coastguard Worker        default_type = torch.tensor([]).type()
8126*da0073e9SAndroid Build Coastguard Worker        for t in torch._tensor_classes:
8127*da0073e9SAndroid Build Coastguard Worker            if t == torch.HalfTensor:
8128*da0073e9SAndroid Build Coastguard Worker                continue  # HalfTensor does not support fill
8129*da0073e9SAndroid Build Coastguard Worker            if t.is_sparse:
8130*da0073e9SAndroid Build Coastguard Worker                continue
8131*da0073e9SAndroid Build Coastguard Worker            if t.is_cuda and not torch.cuda.is_available():
8132*da0073e9SAndroid Build Coastguard Worker                continue
8133*da0073e9SAndroid Build Coastguard Worker            obj = t(100, 100).fill_(1)
8134*da0073e9SAndroid Build Coastguard Worker            obj.__repr__()
8135*da0073e9SAndroid Build Coastguard Worker            str(obj)
8136*da0073e9SAndroid Build Coastguard Worker        # test half tensor
8137*da0073e9SAndroid Build Coastguard Worker        obj = torch.rand(100, 100, device='cpu').half()
8138*da0073e9SAndroid Build Coastguard Worker        obj.__repr__()
8139*da0073e9SAndroid Build Coastguard Worker        str(obj)
8140*da0073e9SAndroid Build Coastguard Worker        for t in torch._storage_classes:
8141*da0073e9SAndroid Build Coastguard Worker            if t == torch.BFloat16Storage:
8142*da0073e9SAndroid Build Coastguard Worker                continue  # Fix once fill is enabled for bfloat16
8143*da0073e9SAndroid Build Coastguard Worker            if t.is_cuda and not torch.cuda.is_available():
8144*da0073e9SAndroid Build Coastguard Worker                continue
8145*da0073e9SAndroid Build Coastguard Worker            if t == torch.BoolStorage or t == torch.cuda.BoolStorage:
8146*da0073e9SAndroid Build Coastguard Worker                obj = t(100).fill_(True)
8147*da0073e9SAndroid Build Coastguard Worker            else:
8148*da0073e9SAndroid Build Coastguard Worker                obj = t(100).fill_(1)
8149*da0073e9SAndroid Build Coastguard Worker            obj.__repr__()
8150*da0073e9SAndroid Build Coastguard Worker            str(obj)
8151*da0073e9SAndroid Build Coastguard Worker
8152*da0073e9SAndroid Build Coastguard Worker        # test complex tensor
8153*da0073e9SAndroid Build Coastguard Worker        # complex tensor print uses two formatters, one for real values
8154*da0073e9SAndroid Build Coastguard Worker        # and the other for imag values. this is consistent with numpy
8155*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([2.3 + 4j, 7 + 6j])
8156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8157*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([2.3000+4.j, 7.0000+6.j])''')
8158*da0073e9SAndroid Build Coastguard Worker
8159*da0073e9SAndroid Build Coastguard Worker        # test complex half tensor
8160*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1.25 + 4j, -7. + 6j], dtype=torch.chalf)
8161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8162*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([ 1.2500+4.j, -7.0000+6.j], dtype=torch.complex32)''')
8163*da0073e9SAndroid Build Coastguard Worker
8164*da0073e9SAndroid Build Coastguard Worker        # test scientific notation for complex tensors
8165*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1e28 + 2j, -1e-28j])
8166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8167*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([1.0000e+28+2.0000e+00j, -0.0000e+00-1.0000e-28j])''')
8168*da0073e9SAndroid Build Coastguard Worker
8169*da0073e9SAndroid Build Coastguard Worker        # test big integer
8170*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(2341234123412341)
8171*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8172*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor(2341234123412341)''')
8173*da0073e9SAndroid Build Coastguard Worker
8174*da0073e9SAndroid Build Coastguard Worker        # test scientific notation
8175*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1e28, 1e-28])
8176*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8177*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([1.0000e+28, 1.0000e-28])''')
8178*da0073e9SAndroid Build Coastguard Worker
8179*da0073e9SAndroid Build Coastguard Worker        # test scientific notation using set_printoptions
8180*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1e2, 1e-2])
8181*da0073e9SAndroid Build Coastguard Worker        torch.set_printoptions(sci_mode=True)
8182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8183*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''')
8184*da0073e9SAndroid Build Coastguard Worker        torch.set_printoptions(sci_mode=False)
8185*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8186*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([  100.0000,     0.0100])''')
8187*da0073e9SAndroid Build Coastguard Worker        torch.set_printoptions(sci_mode=None)  # reset to the default value
8188*da0073e9SAndroid Build Coastguard Worker
8189*da0073e9SAndroid Build Coastguard Worker        # test no leading space if all elements positive
8190*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1, 2])
8191*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8192*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([1, 2])''')
8193*da0073e9SAndroid Build Coastguard Worker
8194*da0073e9SAndroid Build Coastguard Worker        # test for leading space if there are negative elements
8195*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1, -2])
8196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8197*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([ 1, -2])''')
8198*da0073e9SAndroid Build Coastguard Worker
8199*da0073e9SAndroid Build Coastguard Worker        # test inf and nan
8200*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1])
8201*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8202*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([4.0000,    inf, 1.5000,   -inf, 0.0000,    nan, 1.0000])''')
8203*da0073e9SAndroid Build Coastguard Worker
8204*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([4, inf, complex(1.5, inf), complex(-inf, 4), 0, complex(nan, inf), complex(3, nan)])
8205*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.__repr__(), str(y))
8206*da0073e9SAndroid Build Coastguard Worker        expected_str = '''\
8207*da0073e9SAndroid Build Coastguard Workertensor([4.0000+0.j,    inf+0.j, 1.5000+infj,   -inf+4.j, 0.0000+0.j,    nan+infj,
8208*da0073e9SAndroid Build Coastguard Worker        3.0000+nanj])'''
8209*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(y), expected_str)
8210*da0073e9SAndroid Build Coastguard Worker
8211*da0073e9SAndroid Build Coastguard Worker        # test dtype
8212*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
8213*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64)
8214*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.__repr__(), str(x))
8215*da0073e9SAndroid Build Coastguard Worker            expected_str = '''\
8216*da0073e9SAndroid Build Coastguard Workertensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
8217*da0073e9SAndroid Build Coastguard Worker                inf], dtype=torch.float64)'''
8218*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(str(x), expected_str)
8219*da0073e9SAndroid Build Coastguard Worker
8220*da0073e9SAndroid Build Coastguard Worker        # test changing default dtype
8221*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float64):
8222*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.__repr__(), str(x))
8223*da0073e9SAndroid Build Coastguard Worker            expected_str = '''\
8224*da0073e9SAndroid Build Coastguard Workertensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
8225*da0073e9SAndroid Build Coastguard Worker                inf])'''
8226*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(str(x), expected_str)
8227*da0073e9SAndroid Build Coastguard Worker
8228*da0073e9SAndroid Build Coastguard Worker        # test summary
8229*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(10000)
8230*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8231*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([0., 0., 0.,  ..., 0., 0., 0.])''')
8232*da0073e9SAndroid Build Coastguard Worker
8233*da0073e9SAndroid Build Coastguard Worker        # test internal summary function
8234*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 20, 5, 30)
8235*da0073e9SAndroid Build Coastguard Worker        summary = torch._tensor_str.get_summarized_data(x)
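        # with the default print options (edgeitems=3), any dim longer than
        # 2 * edgeitems is summarized to its first and last three slices,
        # which is why the summarized shape is (1, 6, 5, 6)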
8236*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(summary.shape, (1, 6, 5, 6))
8237*da0073e9SAndroid Build Coastguard Worker        first_and_last = [0, 1, 2, -3, -2, -1]
8238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(summary, x[:, first_and_last][..., first_and_last])
8239*da0073e9SAndroid Build Coastguard Worker
8240*da0073e9SAndroid Build Coastguard Worker        # test device
8241*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
8242*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([123], device='cuda:0')
8243*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.__repr__(), str(x))
8244*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''')
8245*da0073e9SAndroid Build Coastguard Worker
8246*da0073e9SAndroid Build Coastguard Worker            # test changing default to cuda
8247*da0073e9SAndroid Build Coastguard Worker            torch.set_default_tensor_type(torch.cuda.FloatTensor)
8248*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.__repr__(), str(x))
8249*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(str(x), '''tensor([123])''')
8250*da0073e9SAndroid Build Coastguard Worker
8251*da0073e9SAndroid Build Coastguard Worker            # test printing a tensor on a different GPU than the current one.
8252*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.device_count() >= 2:
8253*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.device(1):
8254*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(x.__repr__(), str(x))
8255*da0073e9SAndroid Build Coastguard Worker                    self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''')
8256*da0073e9SAndroid Build Coastguard Worker
8257*da0073e9SAndroid Build Coastguard Worker            # test printing cpu tensor when default device is cuda
8258*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor([123], device='cpu')
8259*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y.__repr__(), str(y))
8260*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(str(y), '''tensor([123], device='cpu')''')
8261*da0073e9SAndroid Build Coastguard Worker        torch.set_default_tensor_type(default_type)
8262*da0073e9SAndroid Build Coastguard Worker
8264*da0073e9SAndroid Build Coastguard Worker        # test integral floats and requires_grad
8265*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([123.], requires_grad=True)
8266*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8267*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([123.], requires_grad=True)''')
8268*da0073e9SAndroid Build Coastguard Worker
8269*da0073e9SAndroid Build Coastguard Worker        # test non-contiguous print
8270*da0073e9SAndroid Build Coastguard Worker        # sliced tensor should have > PRINT_OPTS.threshold elements
8271*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(100, 2, 2, 10)
8272*da0073e9SAndroid Build Coastguard Worker        y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1))
8273*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(y), y.__repr__())
8274*da0073e9SAndroid Build Coastguard Worker        expected_str = '''\
8275*da0073e9SAndroid Build Coastguard Workertensor([[[1., 1., 1.,  ..., 1., 1., 1.],
8276*da0073e9SAndroid Build Coastguard Worker         [1., 1., 1.,  ..., 1., 1., 1.]],
8277*da0073e9SAndroid Build Coastguard Worker
8278*da0073e9SAndroid Build Coastguard Worker        [[1., 1., 1.,  ..., 1., 1., 1.],
8279*da0073e9SAndroid Build Coastguard Worker         [1., 1., 1.,  ..., 1., 1., 1.]],
8280*da0073e9SAndroid Build Coastguard Worker
8281*da0073e9SAndroid Build Coastguard Worker        [[1., 1., 1.,  ..., 1., 1., 1.],
8282*da0073e9SAndroid Build Coastguard Worker         [1., 1., 1.,  ..., 1., 1., 1.]],
8283*da0073e9SAndroid Build Coastguard Worker
8284*da0073e9SAndroid Build Coastguard Worker        ...,
8285*da0073e9SAndroid Build Coastguard Worker
8286*da0073e9SAndroid Build Coastguard Worker        [[1., 1., 1.,  ..., 1., 1., 1.],
8287*da0073e9SAndroid Build Coastguard Worker         [1., 1., 1.,  ..., 1., 1., 1.]],
8288*da0073e9SAndroid Build Coastguard Worker
8289*da0073e9SAndroid Build Coastguard Worker        [[1., 1., 1.,  ..., 1., 1., 1.],
8290*da0073e9SAndroid Build Coastguard Worker         [1., 1., 1.,  ..., 1., 1., 1.]],
8291*da0073e9SAndroid Build Coastguard Worker
8292*da0073e9SAndroid Build Coastguard Worker        [[1., 1., 1.,  ..., 1., 1., 1.],
8293*da0073e9SAndroid Build Coastguard Worker         [1., 1., 1.,  ..., 1., 1., 1.]]])\
8294*da0073e9SAndroid Build Coastguard Worker'''
8295*da0073e9SAndroid Build Coastguard Worker
8296*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(y), expected_str)
8297*da0073e9SAndroid Build Coastguard Worker
8298*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(100, 2, 2, 10) * (1 + 1j)
8299*da0073e9SAndroid Build Coastguard Worker        y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1))
8300*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(y), y.__repr__())
8301*da0073e9SAndroid Build Coastguard Worker        expected_str = '''\
8302*da0073e9SAndroid Build Coastguard Workertensor([[[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8303*da0073e9SAndroid Build Coastguard Worker         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8304*da0073e9SAndroid Build Coastguard Worker
8305*da0073e9SAndroid Build Coastguard Worker        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8306*da0073e9SAndroid Build Coastguard Worker         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8307*da0073e9SAndroid Build Coastguard Worker
8308*da0073e9SAndroid Build Coastguard Worker        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8309*da0073e9SAndroid Build Coastguard Worker         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8310*da0073e9SAndroid Build Coastguard Worker
8311*da0073e9SAndroid Build Coastguard Worker        ...,
8312*da0073e9SAndroid Build Coastguard Worker
8313*da0073e9SAndroid Build Coastguard Worker        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8314*da0073e9SAndroid Build Coastguard Worker         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8315*da0073e9SAndroid Build Coastguard Worker
8316*da0073e9SAndroid Build Coastguard Worker        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8317*da0073e9SAndroid Build Coastguard Worker         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]],
8318*da0073e9SAndroid Build Coastguard Worker
8319*da0073e9SAndroid Build Coastguard Worker        [[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
8320*da0073e9SAndroid Build Coastguard Worker         [1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j]]])\
8321*da0073e9SAndroid Build Coastguard Worker'''
8322*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(y), expected_str)
8323*da0073e9SAndroid Build Coastguard Worker
8324*da0073e9SAndroid Build Coastguard Worker        # test printing a 0-dim tensor: we match NumPy's arrayprint style for scalars
8325*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(0.00002)
8326*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8327*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor(2.0000e-05)''')
8328*da0073e9SAndroid Build Coastguard Worker
8329*da0073e9SAndroid Build Coastguard Worker        # test print boolean tensor
8330*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([True])
8331*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8332*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([True])''')
8333*da0073e9SAndroid Build Coastguard Worker
8334*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(True)
8335*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8336*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor(True)''')
8337*da0073e9SAndroid Build Coastguard Worker
8338*da0073e9SAndroid Build Coastguard Worker        # [Numpy] test print float in sci_mode when min < 0.0001.
8339*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.00002])
8340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8341*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([2.0000e-05])''')
8342*da0073e9SAndroid Build Coastguard Worker
8343*da0073e9SAndroid Build Coastguard Worker        # [Numpy] test print complex in sci_mode when real_min < 0.0001 and/or imag_min < 0.0001.
8344*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.00002]) * (1 + 1j)
8345*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8346*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([2.0000e-05+2.0000e-05j])''')
8347*da0073e9SAndroid Build Coastguard Worker
8348*da0073e9SAndroid Build Coastguard Worker        # [Numpy] test print float in sci_mode when max > 1e8.
8349*da0073e9SAndroid Build Coastguard Worker        # TODO: PyTorch uses fixed precision to print, while NumPy uses dragon4_scientific
8350*da0073e9SAndroid Build Coastguard Worker        # to do automatic trimming and padding.
8351*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([123456789.])
8352*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8353*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([1.2346e+08])''')
8354*da0073e9SAndroid Build Coastguard Worker
8355*da0073e9SAndroid Build Coastguard Worker        # [Numpy] test print float in sci_mode when max / min > 1000.
8356*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.01, 11])
8357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8358*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([1.0000e-02, 1.1000e+01])''')
8359*da0073e9SAndroid Build Coastguard Worker
8360*da0073e9SAndroid Build Coastguard Worker        # [Numpy] test print int max / min > 1000, no sci_mode
8361*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1, 1010])
8362*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8363*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([   1, 1010])''')
8364*da0073e9SAndroid Build Coastguard Worker
8365*da0073e9SAndroid Build Coastguard Worker        # [Numpy] test print int > 1e8, no sci_mode
8366*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1000000000])  # 1e9
8367*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8368*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([1000000000])''')
8369*da0073e9SAndroid Build Coastguard Worker
8370*da0073e9SAndroid Build Coastguard Worker        # [Numpy] test printing float in int_mode
8371*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1., 1000.])
8372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8373*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([   1., 1000.])''')
8374*da0073e9SAndroid Build Coastguard Worker
8375*da0073e9SAndroid Build Coastguard Worker        # [Numpy] test printing float in int_mode in sci format when max / min > 1000.
8376*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1., 1010.])
8377*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.__repr__(), str(x))
8378*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''')
8379*da0073e9SAndroid Build Coastguard Worker
8380*da0073e9SAndroid Build Coastguard Worker    def test_sizeof(self) -> None:
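        # __sizeof__ includes a fixed per-object overhead, so differencing
        # against the empty storage isolates the per-element cost: a
        # 100-element storage should be exactly 10x a 10-element one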
8381*da0073e9SAndroid Build Coastguard Worker        sizeof_empty = torch.randn(0).storage().__sizeof__()
8382*da0073e9SAndroid Build Coastguard Worker        sizeof_10 = torch.randn(10).storage().__sizeof__()
8383*da0073e9SAndroid Build Coastguard Worker        sizeof_100 = torch.randn(100).storage().__sizeof__()
8384*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
8385*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
8386*da0073e9SAndroid Build Coastguard Worker
8387*da0073e9SAndroid Build Coastguard Worker        sizeof_empty = torch.randn(0).to(torch.uint8).storage().__sizeof__()
8388*da0073e9SAndroid Build Coastguard Worker        sizeof_10 = torch.randn(10).to(torch.uint8).storage().__sizeof__()
8389*da0073e9SAndroid Build Coastguard Worker        sizeof_100 = torch.randn(100).to(torch.uint8).storage().__sizeof__()
8390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
8391*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
8392*da0073e9SAndroid Build Coastguard Worker
8393*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
8394*da0073e9SAndroid Build Coastguard Worker    def test_resizable(self) -> None:
8395*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5)
8396*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.storage().resizable())
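        # numpy() shares the tensor's memory with the returned ndarray,
        # which locks the storage so it can no longer be resized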
8397*da0073e9SAndroid Build Coastguard Worker        x.numpy()
8398*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.storage().resizable())
8399*da0073e9SAndroid Build Coastguard Worker
8400*da0073e9SAndroid Build Coastguard Worker    def test_iter(self) -> None:
8401*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
8402*da0073e9SAndroid Build Coastguard Worker        for i, sub in enumerate(x):
8403*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sub, x[i])  # noqa: PLR1736
8404*da0073e9SAndroid Build Coastguard Worker
8405*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([])
8406*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(x), [])
8407*da0073e9SAndroid Build Coastguard Worker
8408*da0073e9SAndroid Build Coastguard Worker    def test_new(self) -> None:
8409*da0073e9SAndroid Build Coastguard Worker        x = torch.autograd.Variable(torch.tensor([]))
8410*da0073e9SAndroid Build Coastguard Worker        y = torch.autograd.Variable(torch.randn(4, 4))
8411*da0073e9SAndroid Build Coastguard Worker        z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
8412*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new().shape, [0])
8413*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new(), x)
8414*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new(1, 2).shape, [1, 2])
8415*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new(torch.Size([3, 4])).shape, [3, 4])
8416*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new([3, 4]).shape, [2])
8417*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new([3, 4]).tolist(), [3, 4])
8418*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new((3, 4)).tolist(), [3, 4])
8419*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4])
8420*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4])
8421*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4])
8422*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new(size=(3, 4)).shape, [3, 4])
8423*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new(()).shape, [0])
8424*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new(y.storage()).data_ptr(), y.data_ptr())
8425*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.new(y).data_ptr(), y.data_ptr())
8426*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(x.new(y), y)
8427*da0073e9SAndroid Build Coastguard Worker
8428*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: x.new(z))
8429*da0073e9SAndroid Build Coastguard Worker        # TypeError would be better
8430*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: x.new(z.storage()))
8431*da0073e9SAndroid Build Coastguard Worker
8432*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
8433*da0073e9SAndroid Build Coastguard Worker    def test_pin_memory(self):
8434*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 5)
8435*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.is_pinned())
8436*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
8437*da0073e9SAndroid Build Coastguard Worker            pinned = x.pin_memory()
8438*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(pinned.is_pinned())
8439*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pinned, x)
8440*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(pinned.data_ptr(), x.data_ptr())
8441*da0073e9SAndroid Build Coastguard Worker            # test that pin_memory on an already-pinned tensor has no effect
8442*da0073e9SAndroid Build Coastguard Worker            self.assertIs(pinned, pinned.pin_memory())
8443*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
8444*da0073e9SAndroid Build Coastguard Worker
8445*da0073e9SAndroid Build Coastguard Worker    def test_error_msg_type_translation(self):
8446*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
8447*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
8448*da0073e9SAndroid Build Coastguard Worker                # message includes both Double and Long
8449*da0073e9SAndroid Build Coastguard Worker                '(?=.*Double)(?=.*Long)'):
8450*da0073e9SAndroid Build Coastguard Worker
8451*da0073e9SAndroid Build Coastguard Worker            # Calls model with a LongTensor input but DoubleTensor weights
8452*da0073e9SAndroid Build Coastguard Worker            input = torch.zeros(1, 1, 1, 6, dtype=torch.long)
8453*da0073e9SAndroid Build Coastguard Worker            weight = torch.nn.Parameter(torch.zeros(1, 1, 1, 3, dtype=torch.double))
8454*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False)
8455*da0073e9SAndroid Build Coastguard Worker            model.weight = weight
8456*da0073e9SAndroid Build Coastguard Worker            out = model(input)
8457*da0073e9SAndroid Build Coastguard Worker
8458*da0073e9SAndroid Build Coastguard Worker    def test_apply(self):
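        # apply_ runs a Python callable elementwise, in place, on a CPU tensor;
        # a result that cannot be converted to the tensor's dtype (here a str)
        # raises TypeError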
8459*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1, 6)
8460*da0073e9SAndroid Build Coastguard Worker        res = x.clone().apply_(lambda k: k + k)
8461*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, x * 2)
8462*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: x.apply_(lambda k: "str"))
8463*da0073e9SAndroid Build Coastguard Worker
8464*da0073e9SAndroid Build Coastguard Worker    def test_map(self):
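        # map_ broadcasts the given tensor against self and applies the callable
        # elementwise in place; a non-callable argument raises TypeError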
8465*da0073e9SAndroid Build Coastguard Worker        x = torch.autograd.Variable(torch.randn(3, 3))
8466*da0073e9SAndroid Build Coastguard Worker        y = torch.autograd.Variable(torch.randn(3))
8467*da0073e9SAndroid Build Coastguard Worker        res = x.clone()
8468*da0073e9SAndroid Build Coastguard Worker        res.map_(y, lambda a, b: a + b)
8469*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, x + y)
8470*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(TypeError, "not callable", lambda: res.map_(y, "str"))
8471*da0073e9SAndroid Build Coastguard Worker
8472*da0073e9SAndroid Build Coastguard Worker    def test_map2(self):
8473*da0073e9SAndroid Build Coastguard Worker        x = torch.autograd.Variable(torch.randn(3, 3))
8474*da0073e9SAndroid Build Coastguard Worker        y = torch.autograd.Variable(torch.randn(3))
8475*da0073e9SAndroid Build Coastguard Worker        z = torch.autograd.Variable(torch.randn(1, 3))
8476*da0073e9SAndroid Build Coastguard Worker        res = x.clone()
8477*da0073e9SAndroid Build Coastguard Worker        res.map2_(y, z, lambda a, b, c: a + b * c)
8478*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, x + y * z)
8479*da0073e9SAndroid Build Coastguard Worker        z.requires_grad = True
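        # map2_ does not support autograd, so operands that require grad are rejected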
8480*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
8481*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "requires grad",
8482*da0073e9SAndroid Build Coastguard Worker            lambda: res.map2_(y, z, lambda a, b, c: a + b * c))
8483*da0073e9SAndroid Build Coastguard Worker
8484*da0073e9SAndroid Build Coastguard Worker    def test_Size(self):
8485*da0073e9SAndroid Build Coastguard Worker        x = torch.Size([1, 2, 3])
8486*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x, tuple)
8487*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[0], 1)
8488*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[1], 2)
8489*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[2], 3)
8490*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(x), 3)
8491*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.Size(torch.ones(3)))
8492*da0073e9SAndroid Build Coastguard Worker
8493*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x * 2, torch.Size)
8494*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x[:-1], torch.Size)
8495*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x + x, torch.Size)
8496*da0073e9SAndroid Build Coastguard Worker
8497*da0073e9SAndroid Build Coastguard Worker    def test_Size_scalar(self):
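        # 0-dim integer tensors are accepted as sizes and converted to plain ints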
8498*da0073e9SAndroid Build Coastguard Worker        three = torch.tensor(3)
8499*da0073e9SAndroid Build Coastguard Worker        two = torch.tensor(2)
8500*da0073e9SAndroid Build Coastguard Worker        x = torch.Size([0, 1, two, three, 4])
8501*da0073e9SAndroid Build Coastguard Worker        for i in range(1, 5):
8502*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[i], i)
8503*da0073e9SAndroid Build Coastguard Worker
8504*da0073e9SAndroid Build Coastguard Worker    def test_Size_iter(self):
8505*da0073e9SAndroid Build Coastguard Worker        for sizes in [iter([1, 2, 3, 4, 5]), range(1, 6)]:
8506*da0073e9SAndroid Build Coastguard Worker            x = torch.Size(sizes)
8507*da0073e9SAndroid Build Coastguard Worker            for i in range(0, 5):
8508*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x[i], i + 1)
8509*da0073e9SAndroid Build Coastguard Worker
8510*da0073e9SAndroid Build Coastguard Worker    def test_t_not_2d_error(self):
8511*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t())
8512*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_())
8513*da0073e9SAndroid Build Coastguard Worker
8514*da0073e9SAndroid Build Coastguard Worker    # skip this test for now as it affects all tests
8515*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(True, "flush_denormal not supported")
8516*da0073e9SAndroid Build Coastguard Worker    def test_set_flush_denormal(self):
8517*da0073e9SAndroid Build Coastguard Worker        tiny_float = 1e-42
8518*da0073e9SAndroid Build Coastguard Worker        tiny_double = 1e-320
8519*da0073e9SAndroid Build Coastguard Worker        float_tensor = torch.FloatTensor([1.0, tiny_float])
8520*da0073e9SAndroid Build Coastguard Worker        double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double])
8521*da0073e9SAndroid Build Coastguard Worker
8522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(float_tensor[0], 1.0, atol=0.0, rtol=0)
8523*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(float_tensor[1], tiny_float, atol=tiny_float / 16, rtol=0)
8524*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double_tensor[0], 1.0, atol=0.0, rtol=0)
8525*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double_tensor[1], tiny_float, atol=0.0, rtol=0)
8526*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double_tensor[2], tiny_double, atol=0.0, rtol=0)
8527*da0073e9SAndroid Build Coastguard Worker
8528*da0073e9SAndroid Build Coastguard Worker        torch.set_flush_denormal(True)
8529*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(float_tensor[0], 1.0, atol=0.0, rtol=0)
8530*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(float_tensor[1], 0.0, atol=0.0, rtol=0)  # tiny_float to zero
8531*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double_tensor[0], 1.0, atol=0.0, rtol=0)
8532*da0073e9SAndroid Build Coastguard Worker        # tiny_float is a normal value in double precision, so it is not flushed to zero
8533*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double_tensor[1], tiny_float, atol=0.0, rtol=0)
8534*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double_tensor[2], 0.0, atol=0.0, rtol=0)  # tiny_double to zero
8535*da0073e9SAndroid Build Coastguard Worker        torch.set_flush_denormal(False)
8536*da0073e9SAndroid Build Coastguard Worker
8537*da0073e9SAndroid Build Coastguard Worker    def test_show_config(self):
8538*da0073e9SAndroid Build Coastguard Worker        # We can't usefully test the output; just make sure this doesn't crash
8539*da0073e9SAndroid Build Coastguard Worker        torch.__config__.show()
8540*da0073e9SAndroid Build Coastguard Worker
8541*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_FBCODE, "CXX_FLAGS is only for OSS build.")
8542*da0073e9SAndroid Build Coastguard Worker    def test_cxx_flags(self):
8543*da0073e9SAndroid Build Coastguard Worker        torch.__config__._cxx_flags()
8544*da0073e9SAndroid Build Coastguard Worker
8545*da0073e9SAndroid Build Coastguard Worker    def test_parallel_info(self):
8546*da0073e9SAndroid Build Coastguard Worker        torch.__config__.parallel_info()
8547*da0073e9SAndroid Build Coastguard Worker
8548*da0073e9SAndroid Build Coastguard Worker    def test_get_cpu_capability(self):
8549*da0073e9SAndroid Build Coastguard Worker        # This method is primarily exposed for torchvision's resize
8550*da0073e9SAndroid Build Coastguard Worker        torch.backends.cpu.get_cpu_capability()
8551*da0073e9SAndroid Build Coastguard Worker
8552*da0073e9SAndroid Build Coastguard Worker        # We have to ensure the method is scriptable with TorchScript, since
8553*da0073e9SAndroid Build Coastguard Worker        # torchvision's resize needs to be scriptable
8554*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(torch.backends.cpu.get_cpu_capability)
8555*da0073e9SAndroid Build Coastguard Worker
8556*da0073e9SAndroid Build Coastguard Worker    @slowTest
8557*da0073e9SAndroid Build Coastguard Worker    def test_slow_test(self):
8558*da0073e9SAndroid Build Coastguard Worker        # Just a smoketest to make sure our slowTest decorator works.
8559*da0073e9SAndroid Build Coastguard Worker        pass
8560*da0073e9SAndroid Build Coastguard Worker
8561*da0073e9SAndroid Build Coastguard Worker    def test_is_nonzero(self):
8562*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
8563*da0073e9SAndroid Build Coastguard Worker            torch.tensor([]).is_nonzero()
8564*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
8565*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 0]).is_nonzero()
8566*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.tensor(0).is_nonzero())
8567*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.tensor(1).is_nonzero())
8568*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.tensor([0]).is_nonzero())
8569*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.tensor([1]).is_nonzero())
8570*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.tensor([[0]]).is_nonzero())
8571*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.tensor([[1]]).is_nonzero())
8572*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.tensor(0.1).is_nonzero())
8573*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.tensor(-0.1).is_nonzero())
8574*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.tensor(0.0).is_nonzero())
8575*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.tensor(True).is_nonzero())
8576*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.tensor(False).is_nonzero())
8577*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.tensor(0 + 0j).is_nonzero())
8578*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.tensor(0 + 0.1j).is_nonzero())
8579*da0073e9SAndroid Build Coastguard Worker
8580*da0073e9SAndroid Build Coastguard Worker    def test_assert_async(self):
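        # _assert_async checks that the tensor holds a single nonzero value;
        # on CUDA the check may run asynchronously, hence the name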
8581*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
8582*da0073e9SAndroid Build Coastguard Worker            torch._assert_async(torch.tensor([]))
8583*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
8584*da0073e9SAndroid Build Coastguard Worker            torch._assert_async(torch.tensor([0, 0]))
8585*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
8586*da0073e9SAndroid Build Coastguard Worker            torch._assert_async(torch.tensor(0))
8587*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(1))
8588*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(0.1))
8589*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(-0.1))
8590*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
8591*da0073e9SAndroid Build Coastguard Worker            torch._assert_async(torch.tensor(0.0))
8592*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(True))
8593*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
8594*da0073e9SAndroid Build Coastguard Worker            torch._assert_async(torch.tensor(False))
8595*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(0 + 0.1j))
8596*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
8597*da0073e9SAndroid Build Coastguard Worker            torch._assert_async(torch.tensor(0 + 0j))
8598*da0073e9SAndroid Build Coastguard Worker
8599*da0073e9SAndroid Build Coastguard Worker    # NB: we must not be built with CUDA; if we are built with CUDA but no CUDA
8600*da0073e9SAndroid Build Coastguard Worker    # is available, we get a different error.
8601*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(torch.backends.cuda.is_built() or IS_SANDCASTLE, "CUDA is built, can't test CUDA not built error")
8602*da0073e9SAndroid Build Coastguard Worker    def test_cuda_not_built(self):
8603*da0073e9SAndroid Build Coastguard Worker        msg = "Torch not compiled with CUDA enabled"
8604*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(AssertionError, msg, lambda: torch.cuda.current_device())
8605*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1], device="cuda"))
8606*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).cuda())
8607*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(TypeError, msg, lambda: torch.cuda.FloatTensor())
8608*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(TypeError, msg, lambda: torch.set_default_tensor_type(torch.cuda.FloatTensor))
8609*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).to(device="cuda"))
8610*da0073e9SAndroid Build Coastguard Worker
8611*da0073e9SAndroid Build Coastguard Worker    def test_has_internal_overlap(self):
8612*da0073e9SAndroid Build Coastguard Worker        OVERLAP_NO = 0
8613*da0073e9SAndroid Build Coastguard Worker        OVERLAP_YES = 1
8614*da0073e9SAndroid Build Coastguard Worker        OVERLAP_TOO_HARD = 2
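        # these codes correspond to ATen's internal MemOverlap result:
        # no overlap, definite overlap, or too hard to analyze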
8615*da0073e9SAndroid Build Coastguard Worker
8616*da0073e9SAndroid Build Coastguard Worker        # Check for contiguous tensors
8617*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3)
8618*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._debug_has_internal_overlap(a), OVERLAP_NO)
8619*da0073e9SAndroid Build Coastguard Worker
8620*da0073e9SAndroid Build Coastguard Worker        # Checks for zero strides
8621*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(1, 3)
8622*da0073e9SAndroid Build Coastguard Worker        b_expanded = b.expand(4, 3)
8623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._debug_has_internal_overlap(b_expanded), OVERLAP_YES)
8624*da0073e9SAndroid Build Coastguard Worker
8625*da0073e9SAndroid Build Coastguard Worker        # Check for zero strided, size 1 axis, in non-contiguous storage (gh-33812)
8626*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(10).as_strided([2, 1, 5], [1, 0, 2])
8627*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_NO)
8628*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(2, 1, 10)[::2].as_strided((2, 1, 5), (10, 0, 2))
8629*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_TOO_HARD)
8630*da0073e9SAndroid Build Coastguard Worker
8631*da0073e9SAndroid Build Coastguard Worker    def test_allow_tensor_metadata_change(self):
8632*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 3)
8633*da0073e9SAndroid Build Coastguard Worker        # Metadata changes are allowed on view tensors that are created from detach().
8634*da0073e9SAndroid Build Coastguard Worker
8635*da0073e9SAndroid Build Coastguard Worker    def test_memory_format(self):
8636*da0073e9SAndroid Build Coastguard Worker        def test_helper(x, memory_format):
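            # converting to channels_last(_3d) yields a tensor that is contiguous
            # only in that memory format, with the values left unchanged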
8637*da0073e9SAndroid Build Coastguard Worker            y = x.contiguous(memory_format=memory_format)
8638*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
8639*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.is_contiguous(memory_format=memory_format))
8640*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, x)
8641*da0073e9SAndroid Build Coastguard Worker
8642*da0073e9SAndroid Build Coastguard Worker        test_helper(torch.randn(4, 3, 8, 8), torch.channels_last)
8643*da0073e9SAndroid Build Coastguard Worker        test_helper(torch.randn(4, 3, 8, 8, 8), torch.channels_last_3d)
8644*da0073e9SAndroid Build Coastguard Worker
8645*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_contiguous_returns_same_tensor_if_already_satisfies(self):
8646*da0073e9SAndroid Build Coastguard Worker        def test_helper(x, memory_format):
8647*da0073e9SAndroid Build Coastguard Worker            alias = x.contiguous(memory_format=memory_format)
8648*da0073e9SAndroid Build Coastguard Worker            alias.fill_(7)
8649*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, alias)
8650*da0073e9SAndroid Build Coastguard Worker
8651*da0073e9SAndroid Build Coastguard Worker        test_helper(torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2), torch.channels_last)
8652*da0073e9SAndroid Build Coastguard Worker        test_helper(torch.randn(4, 8, 8, 8, 3).permute(0, 4, 1, 2, 3), torch.channels_last_3d)
8653*da0073e9SAndroid Build Coastguard Worker
8654*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_empty(self):
8655*da0073e9SAndroid Build Coastguard Worker        def test_helper(dim1, dim2, memory_format):
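            # channels_last needs a 4-d size (channels_last_3d a 5-d size), so
            # constructing with the lower-rank dim1 should raise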
8656*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
8657*da0073e9SAndroid Build Coastguard Worker                x = torch.empty(dim1, memory_format=memory_format)
8658*da0073e9SAndroid Build Coastguard Worker            x = torch.empty(dim2, memory_format=memory_format)
8659*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.is_contiguous(memory_format=memory_format))
8660*da0073e9SAndroid Build Coastguard Worker
8661*da0073e9SAndroid Build Coastguard Worker        test_helper((3, 3), (3, 3, 3, 3), torch.channels_last)
8662*da0073e9SAndroid Build Coastguard Worker        test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d)
8663*da0073e9SAndroid Build Coastguard Worker
8664*da0073e9SAndroid Build Coastguard Worker    def test_dim_order(self):
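        # dim_order() lists dimensions from outermost to innermost in memory,
        # i.e. in order of decreasing stride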
8665*da0073e9SAndroid Build Coastguard Worker        shape = (2, 3, 5, 7)
8666*da0073e9SAndroid Build Coastguard Worker
8667*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(shape)
8668*da0073e9SAndroid Build Coastguard Worker        self.assertSequenceEqual(t.dim_order(), (0, 1, 2, 3), seq_type=tuple)
8669*da0073e9SAndroid Build Coastguard Worker        # transpose doesn't really change the underlying physical memory,
8670*da0073e9SAndroid Build Coastguard Worker        # so we expect dim_order to change to reflect that (just like strides do)
8671*da0073e9SAndroid Build Coastguard Worker        self.assertSequenceEqual(t.transpose(0, 1).dim_order(), (1, 0, 2, 3))
8672*da0073e9SAndroid Build Coastguard Worker
8673*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(shape, memory_format=torch.channels_last)
8674*da0073e9SAndroid Build Coastguard Worker        self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 1))
8675*da0073e9SAndroid Build Coastguard Worker
8676*da0073e9SAndroid Build Coastguard Worker        t = torch.empty((2, 3, 5, 7, 8), memory_format=torch.channels_last_3d)
8677*da0073e9SAndroid Build Coastguard Worker        self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 4, 1))
8678*da0073e9SAndroid Build Coastguard Worker
8679*da0073e9SAndroid Build Coastguard Worker        for dim_order in itertools.permutations(range(4)):
8680*da0073e9SAndroid Build Coastguard Worker            self.assertSequenceEqual(
8681*da0073e9SAndroid Build Coastguard Worker                dim_order, torch.empty_permuted(shape, dim_order).dim_order()
8682*da0073e9SAndroid Build Coastguard Worker            )
8683*da0073e9SAndroid Build Coastguard Worker
8684*da0073e9SAndroid Build Coastguard Worker        for shape in [(2, 2, 2, 2), (2, 1, 2, 2), (2, 2, 1, 2), (2, 2, 2, 1), (2, 2, 1, 1), (2, 1, 1, 2)]:
8685*da0073e9SAndroid Build Coastguard Worker            for memory_format in (torch.contiguous_format, torch.channels_last):
8686*da0073e9SAndroid Build Coastguard Worker                t = torch.empty(shape).to(memory_format=memory_format)
8687*da0073e9SAndroid Build Coastguard Worker                if memory_format == torch.contiguous_format:
8688*da0073e9SAndroid Build Coastguard Worker                    dim_order_target = list(range(len(shape)))
8689*da0073e9SAndroid Build Coastguard Worker                elif memory_format == torch.channels_last:
8690*da0073e9SAndroid Build Coastguard Worker                    dim_order_target = [0, *list(range(2, len(shape))), 1]
8691*da0073e9SAndroid Build Coastguard Worker
8692*da0073e9SAndroid Build Coastguard Worker                self.assertSequenceEqual(dim_order_target, t.dim_order())
8693*da0073e9SAndroid Build Coastguard Worker
8694*da0073e9SAndroid Build Coastguard Worker    def test_subclass_tensors(self):
8695*da0073e9SAndroid Build Coastguard Worker        # raise an error when trying to subclass FloatTensor
8696*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):
8697*da0073e9SAndroid Build Coastguard Worker            class Foo1(torch.FloatTensor):
8698*da0073e9SAndroid Build Coastguard Worker                pass
8699*da0073e9SAndroid Build Coastguard Worker
8700*da0073e9SAndroid Build Coastguard Worker        # but allow subclassing Tensor:
8701*da0073e9SAndroid Build Coastguard Worker        class Foo2(torch.Tensor):
8702*da0073e9SAndroid Build Coastguard Worker            def foo(self):
8703*da0073e9SAndroid Build Coastguard Worker                return 5
8704*da0073e9SAndroid Build Coastguard Worker        f = Foo2()
8705*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f.foo(), 5)
8706*da0073e9SAndroid Build Coastguard Worker
8707*da0073e9SAndroid Build Coastguard Worker    def test_ndim(self):
8708*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 3)
8709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, a.ndim)
8710*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(())
8711*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, b.ndim)
8712*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(1, 0)
8713*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, c.ndim)
8714*da0073e9SAndroid Build Coastguard Worker
8715*da0073e9SAndroid Build Coastguard Worker    def test_nbytes(self):
8716*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2, 3, dtype=torch.float64)
8717*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.numel() * a.element_size(), a.nbytes)
8718*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(())
8719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.numel() * b.element_size(), b.nbytes)
8720*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(1, 0)
8721*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c.numel() * c.element_size(), c.nbytes)
8722*da0073e9SAndroid Build Coastguard Worker
8723*da0073e9SAndroid Build Coastguard Worker    def test_fill_diagonal(self):
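        # fill_diagonal_ writes v where all indices are equal; with wrap=True the
        # diagonal of a tall matrix wraps around, restarting every ncols + 1 rows
        # (mirroring numpy.fill_diagonal)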
8724*da0073e9SAndroid Build Coastguard Worker        a1 = torch.randn(7, 3)
8725*da0073e9SAndroid Build Coastguard Worker        a2 = a1.clone()
8726*da0073e9SAndroid Build Coastguard Worker        v = 1
8727*da0073e9SAndroid Build Coastguard Worker        for i in range(3):
8728*da0073e9SAndroid Build Coastguard Worker            a2[i][i] = v
8729*da0073e9SAndroid Build Coastguard Worker        a1.fill_diagonal_(v)
8730*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a1, a2)
8731*da0073e9SAndroid Build Coastguard Worker
8732*da0073e9SAndroid Build Coastguard Worker        b1 = torch.randn(7, 3)
8733*da0073e9SAndroid Build Coastguard Worker        b2 = b1.clone()
8734*da0073e9SAndroid Build Coastguard Worker        for i in range(3):
8735*da0073e9SAndroid Build Coastguard Worker            b2[i][i] = v
8736*da0073e9SAndroid Build Coastguard Worker            b2[i + 4][i] = v
8737*da0073e9SAndroid Build Coastguard Worker        b1.fill_diagonal_(v, wrap=True)
8738*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b1, b2)
8739*da0073e9SAndroid Build Coastguard Worker
8740*da0073e9SAndroid Build Coastguard Worker        c1 = torch.rand(3, 3, 3)
8741*da0073e9SAndroid Build Coastguard Worker        c2 = c1.clone()
8742*da0073e9SAndroid Build Coastguard Worker        for i in range(3):
8743*da0073e9SAndroid Build Coastguard Worker            c2[i][i][i] = v
8744*da0073e9SAndroid Build Coastguard Worker        c1.fill_diagonal_(v)
8745*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c1, c2)
8746*da0073e9SAndroid Build Coastguard Worker
8747*da0073e9SAndroid Build Coastguard Worker        # non-contiguous tensor
8748*da0073e9SAndroid Build Coastguard Worker        d1 = torch.rand(3, 3, 3)[:, 1, ...]
8749*da0073e9SAndroid Build Coastguard Worker        d2 = d1.clone()
8750*da0073e9SAndroid Build Coastguard Worker        for i in range(3):
8751*da0073e9SAndroid Build Coastguard Worker            d2[i][i] = v
8752*da0073e9SAndroid Build Coastguard Worker        d1.fill_diagonal_(v)
8753*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d1, d2)
8754*da0073e9SAndroid Build Coastguard Worker
8755*da0073e9SAndroid Build Coastguard Worker        e1 = torch.rand(7, 3, 3)[:, 1, ...]
8756*da0073e9SAndroid Build Coastguard Worker        e2 = e1.clone()
8757*da0073e9SAndroid Build Coastguard Worker        for i in range(3):
8758*da0073e9SAndroid Build Coastguard Worker            e2[i][i] = v
8759*da0073e9SAndroid Build Coastguard Worker            e2[i + 4][i] = v
8760*da0073e9SAndroid Build Coastguard Worker        e1.fill_diagonal_(v, wrap=True)
8761*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(e1, e2)
8762*da0073e9SAndroid Build Coastguard Worker
8763*da0073e9SAndroid Build Coastguard Worker    def test_setting_real_imag_to_a_number(self):
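        # .real and .imag are writable views into a complex tensor, so assigning
        # a scalar broadcasts it over the corresponding component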
8764*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, dtype=torch.cfloat)
8765*da0073e9SAndroid Build Coastguard Worker        x.real = 0
8766*da0073e9SAndroid Build Coastguard Worker        x.imag = 0
8767*da0073e9SAndroid Build Coastguard Worker        zeros = torch.zeros(4)
8768*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.real, zeros)
8769*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.imag, zeros)
8770*da0073e9SAndroid Build Coastguard Worker
8771*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm_cpu_inference(self):
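        # in eval mode with default running stats (mean=0, var=1) and default
        # affine parameters, batch norm reduces to y = x / sqrt(1 + eps), which
        # is where the 0.499997... style expected values below come from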
8772*da0073e9SAndroid Build Coastguard Worker        # input nchw in (2,1,1,1), (2,2,2,2)
8773*da0073e9SAndroid Build Coastguard Worker        inputs = [
8774*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[[[-0.5000]]], [[[0.5000]]]]),
8775*da0073e9SAndroid Build Coastguard Worker            torch.tensor([
8776*da0073e9SAndroid Build Coastguard Worker                [
8777*da0073e9SAndroid Build Coastguard Worker                    [[-0.5000, 0.5000], [-1.0000, 1.0000]],
8778*da0073e9SAndroid Build Coastguard Worker                    [[-0.2500, -0.5000], [0.2500, 0.5000]]
8779*da0073e9SAndroid Build Coastguard Worker                ],
8780*da0073e9SAndroid Build Coastguard Worker                [
8781*da0073e9SAndroid Build Coastguard Worker                    [[0.1000, 1.0000], [1.0000, 0.1000]],
8782*da0073e9SAndroid Build Coastguard Worker                    [[1.0000, 0.5000], [1.5000, -1.5000]]
8783*da0073e9SAndroid Build Coastguard Worker                ]])]
8784*da0073e9SAndroid Build Coastguard Worker        # output nchw in (2,1,1,1), (2,2,2,2)
8785*da0073e9SAndroid Build Coastguard Worker        outputs = [
8786*da0073e9SAndroid Build Coastguard Worker            torch.tensor([
8787*da0073e9SAndroid Build Coastguard Worker                [[[-0.499997496604919433593750000]]],
8788*da0073e9SAndroid Build Coastguard Worker                [[[0.499997496604919433593750000]]]]),
8789*da0073e9SAndroid Build Coastguard Worker            torch.tensor([
8790*da0073e9SAndroid Build Coastguard Worker                [[[-0.499997496604919433593750000, 0.499997496604919433593750000],
8791*da0073e9SAndroid Build Coastguard Worker                  [-0.999994993209838867187500000, 0.999994993209838867187500000]],
8792*da0073e9SAndroid Build Coastguard Worker                 [[-0.249998748302459716796875000, -0.499997496604919433593750000],
8793*da0073e9SAndroid Build Coastguard Worker                  [0.249998748302459716796875000, 0.499997496604919433593750000]]],
8794*da0073e9SAndroid Build Coastguard Worker                [[[0.099999502301216125488281250, 0.999994993209838867187500000],
8795*da0073e9SAndroid Build Coastguard Worker                  [0.999994993209838867187500000, 0.099999502301216125488281250]],
8796*da0073e9SAndroid Build Coastguard Worker                 [[0.999994993209838867187500000, 0.499997496604919433593750000],
8797*da0073e9SAndroid Build Coastguard Worker                  [1.499992489814758300781250000, -1.499992489814758300781250000]]]])]
8798*da0073e9SAndroid Build Coastguard Worker
8799*da0073e9SAndroid Build Coastguard Worker
8800*da0073e9SAndroid Build Coastguard Worker        for i in range(len(inputs)):
8801*da0073e9SAndroid Build Coastguard Worker            for affine in [False, True]:
8802*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.BatchNorm2d(inputs[i].size()[1], 1e-05, 0.1, affine=affine)
8803*da0073e9SAndroid Build Coastguard Worker                m.eval()
8804*da0073e9SAndroid Build Coastguard Worker                # contiguous case
8805*da0073e9SAndroid Build Coastguard Worker                input1 = inputs[i].contiguous()
8806*da0073e9SAndroid Build Coastguard Worker                output1 = m(input1)
8807*da0073e9SAndroid Build Coastguard Worker                # non-contiguous case
8808*da0073e9SAndroid Build Coastguard Worker                input2 = input1.permute(0, 1, 3, 2)
8809*da0073e9SAndroid Build Coastguard Worker                output2 = m(input2).permute(0, 1, 3, 2)
8810*da0073e9SAndroid Build Coastguard Worker                # channels last case
8811*da0073e9SAndroid Build Coastguard Worker                input3 = input1.contiguous(memory_format=torch.channels_last)
8812*da0073e9SAndroid Build Coastguard Worker                output3 = m(input3)
8813*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output3, outputs[i])
8814*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output3, output1)
8815*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output3, output2)
8816*da0073e9SAndroid Build Coastguard Worker
8817*da0073e9SAndroid Build Coastguard Worker    # FIXME: move these meta tests to their own test suite/class or
8818*da0073e9SAndroid Build Coastguard Worker    #   distribute them among the appropriate test suites for their ops
8819*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Fails after Triton update, see https://github.com/pytorch/pytorch/issues/94687")
8820*da0073e9SAndroid Build Coastguard Worker    def test_empty_meta(self):
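        # meta tensors carry only metadata (shape/dtype/device), so shape
        # propagation works but materializing a value via item() raises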
8821*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2 ** 20, 2 ** 20, device='meta')
8822*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(2 ** 20, device='meta')
8823*da0073e9SAndroid Build Coastguard Worker        z = x + y
8824*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
8825*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: z[0][0].item())
8826*da0073e9SAndroid Build Coastguard Worker
8827*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Fails after Triton update, see https://github.com/pytorch/pytorch/issues/94687")
8828*da0073e9SAndroid Build Coastguard Worker    def test_format_scalar_meta(self):
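        # a 0-dim meta tensor has no value to format, so format() presumably
        # falls back to the repr, which is what this checks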
8829*da0073e9SAndroid Build Coastguard Worker        x = torch.empty((), device='meta')
8830*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(format(x), repr(x))
8831*da0073e9SAndroid Build Coastguard Worker
8832*da0073e9SAndroid Build Coastguard Worker    def test_upsample_nearest1d_meta(self):
8833*da0073e9SAndroid Build Coastguard Worker        # TODO: this test should be triggered by test_nn.py but right
8834*da0073e9SAndroid Build Coastguard Worker        # now meta is not enabled (and even if it was, we are probably
8835*da0073e9SAndroid Build Coastguard Worker        # missing too many meta functions to get through the test unmolested)
8836*da0073e9SAndroid Build Coastguard Worker
8837*da0073e9SAndroid Build Coastguard Worker        # NB: Can't make the exponent too big, or it will overflow a
8838*da0073e9SAndroid Build Coastguard Worker        # signed 64-bit integer
8839*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2 * 10 ** 8, 3, 2 * 10 ** 8, device='meta')
8840*da0073e9SAndroid Build Coastguard Worker        z = torch.nn.functional.interpolate(x, scale_factor=2)
8841*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
8842*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
8843*da0073e9SAndroid Build Coastguard Worker
8844*da0073e9SAndroid Build Coastguard Worker        # TODO: the out tests cannot be triggered by test_nn.py because
8845*da0073e9SAndroid Build Coastguard Worker        # we don't actually do out= arguments for nn functions, so there
8846*da0073e9SAndroid Build Coastguard Worker        # is no public API by which to get the out version
8847*da0073e9SAndroid Build Coastguard Worker
8848*da0073e9SAndroid Build Coastguard Worker        # interpolate doesn't seem to support out=
8849*da0073e9SAndroid Build Coastguard Worker        # (it's unclear why passing None here doesn't work)
8850*da0073e9SAndroid Build Coastguard Worker        z = torch.empty(0, device='meta')
8851*da0073e9SAndroid Build Coastguard Worker        torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
8852*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
8853*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
8854*da0073e9SAndroid Build Coastguard Worker
8855*da0073e9SAndroid Build Coastguard Worker    def test_upsample_nearest2d_meta(self):
8856*da0073e9SAndroid Build Coastguard Worker        # TODO: the out tests cannot be triggered by test_nn.py because
8857*da0073e9SAndroid Build Coastguard Worker        # we don't actually do out= arguments for nn functions, so there
8858*da0073e9SAndroid Build Coastguard Worker        # is no public API by which to get the out version
8859*da0073e9SAndroid Build Coastguard Worker
8860*da0073e9SAndroid Build Coastguard Worker        # Make sure we don't clobber strides of out tensor.  NB: this
8861*da0073e9SAndroid Build Coastguard Worker        # test must be done on 2d/3d, because 1d doesn't have any meaningful
8862*da0073e9SAndroid Build Coastguard Worker        # layout support
8863*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(4, 3, 8, 8, device='meta')
8864*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(4, 3, 16, 16, device='meta', memory_format=torch.channels_last)
8865*da0073e9SAndroid Build Coastguard Worker        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
8866*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
8867*da0073e9SAndroid Build Coastguard Worker
8868*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(4, 3, 8, 8, device='meta', memory_format=torch.channels_last)
8869*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(4, 3, 16, 16, device='meta')
8870*da0073e9SAndroid Build Coastguard Worker        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
8871*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out.is_contiguous())
8872*da0073e9SAndroid Build Coastguard Worker
8873*da0073e9SAndroid Build Coastguard Worker        # But if resize occurs, do clobber
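        # (a 0-element out tensor gets resized, so it takes on the memory format
        # suggested by the input instead of keeping its original layout)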
8874*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(4, 3, 8, 8, device='meta', memory_format=torch.channels_last)
8875*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(0, device='meta')
8876*da0073e9SAndroid Build Coastguard Worker        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
8877*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
8878*da0073e9SAndroid Build Coastguard Worker
8879*da0073e9SAndroid Build Coastguard Worker        # Complain if out dtype mismatch
8880*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(4, 3, 8, 8, device='meta', dtype=torch.float)
8881*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double)
8882*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
8883*da0073e9SAndroid Build Coastguard Worker            RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
8884*da0073e9SAndroid Build Coastguard Worker            """Expected out tensor to have dtype torch.float32 but got torch.float64 instead"""
8885*da0073e9SAndroid Build Coastguard Worker        )
8886*da0073e9SAndroid Build Coastguard Worker
8887*da0073e9SAndroid Build Coastguard Worker        # Complain if out device mismatch
8888*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(0, 3, 8, 8, device='meta')
8889*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(0, 3, 16, 16, device='cpu')
8890*da0073e9SAndroid Build Coastguard Worker        # FIXME: compiling should properly error with a device mismatch.
8891*da0073e9SAndroid Build Coastguard Worker        if not TEST_WITH_TORCHINDUCTOR:
8892*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedRaisesInline(
8893*da0073e9SAndroid Build Coastguard Worker                RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
8894*da0073e9SAndroid Build Coastguard Worker                """Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!"""
8895*da0073e9SAndroid Build Coastguard Worker            )
8896*da0073e9SAndroid Build Coastguard Worker
8897*da0073e9SAndroid Build Coastguard Worker    def test_add_meta_scalar(self):
8898*da0073e9SAndroid Build Coastguard Worker        # From https://github.com/pytorch/pytorch/issues/53815
8899*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2, device='meta')
8900*da0073e9SAndroid Build Coastguard Worker        y = x + 2
8901*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.size(), x.size())
8902*da0073e9SAndroid Build Coastguard Worker
8903*da0073e9SAndroid Build Coastguard Worker    def test_normal_shape(self):
8904*da0073e9SAndroid Build Coastguard Worker        warned = False
8905*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
8906*da0073e9SAndroid Build Coastguard Worker            tensor1 = torch.rand(1, device=device)
8907*da0073e9SAndroid Build Coastguard Worker            tensor4 = torch.rand(4, device=device)
8908*da0073e9SAndroid Build Coastguard Worker            tensor120 = torch.rand(120, device=device)
8909*da0073e9SAndroid Build Coastguard Worker            tensor2145 = torch.rand(2, 1, 4, 5, device=device)
8910*da0073e9SAndroid Build Coastguard Worker            tensor2345 = torch.rand(2, 3, 4, 5, device=device)
8911*da0073e9SAndroid Build Coastguard Worker            tensor2345_non_contiguous = torch.rand(2, 4, 3, 5, device=device).permute(0, 2, 1, 3)
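            # permute(0, 2, 1, 3) on a (2, 4, 3, 5) tensor yields shape (2, 3, 4, 5)
            # but with non-contiguous strides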
8912*da0073e9SAndroid Build Coastguard Worker            tensor2345_channels_last = tensor2345.contiguous(memory_format=torch.channels_last)
8913*da0073e9SAndroid Build Coastguard Worker            output2345 = torch.zeros(2, 3, 4, 5, device=device)
8914*da0073e9SAndroid Build Coastguard Worker            output345 = torch.zeros(3, 4, 5, device=device)
8915*da0073e9SAndroid Build Coastguard Worker
8916*da0073e9SAndroid Build Coastguard Worker            # inputs have same size
8917*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(tensor2345, tensor2345).size(), (2, 3, 4, 5))
8918*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(tensor2345_non_contiguous, tensor2345).size(), (2, 3, 4, 5))
8919*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(tensor2345, tensor2345_channels_last).size(), (2, 3, 4, 5))
8920*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(tensor2345_non_contiguous, tensor2345_channels_last).size(), (2, 3, 4, 5))
8921*da0073e9SAndroid Build Coastguard Worker
8922*da0073e9SAndroid Build Coastguard Worker            # scalar case
8923*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(tensor2345, 2).size(), (2, 3, 4, 5))
8924*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(2, tensor2345).size(), (2, 3, 4, 5))
8925*da0073e9SAndroid Build Coastguard Worker
8926*da0073e9SAndroid Build Coastguard Worker            # inputs are expandable tensors
8927*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(tensor2345, tensor1).size(), (2, 3, 4, 5))
8928*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(tensor2145, tensor2345).size(), (2, 3, 4, 5))
8929*da0073e9SAndroid Build Coastguard Worker
8930*da0073e9SAndroid Build Coastguard Worker            # inputs are non-expandable tensors, but they have same number of elements
8931*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
8932*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
8933*da0073e9SAndroid Build Coastguard Worker                    r"The size of tensor a \(120\) must match the size of "
8934*da0073e9SAndroid Build Coastguard Worker                    r"tensor b \(5\) at non-singleton dimension 3"):
8935*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,))
8936*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
8937*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
8938*da0073e9SAndroid Build Coastguard Worker                    r"The size of tensor a \(5\) must match the size of "
8939*da0073e9SAndroid Build Coastguard Worker                    r"tensor b \(120\) at non-singleton dimension 3"):
8940*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.normal(tensor2345, tensor120).size(), (2, 3, 4, 5))
8941*da0073e9SAndroid Build Coastguard Worker
8942*da0073e9SAndroid Build Coastguard Worker            # inputs are non-expandable tensors and they don't have same number of elements
8943*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
8944*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
8945*da0073e9SAndroid Build Coastguard Worker                    r"The size of tensor a \(5\) must match the size of "
8946*da0073e9SAndroid Build Coastguard Worker                    r"tensor b \(4\) at non-singleton dimension 3"):
8947*da0073e9SAndroid Build Coastguard Worker                torch.normal(tensor2345, tensor4)
8948*da0073e9SAndroid Build Coastguard Worker
8949*da0073e9SAndroid Build Coastguard Worker            # output and inputs are size compatible
8950*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.normal(tensor2345, tensor2345, out=output2345).size(), (2, 3, 4, 5))
8951*da0073e9SAndroid Build Coastguard Worker
8952*da0073e9SAndroid Build Coastguard Worker            # output and inputs are not size compatible
8953*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(
8954*da0073e9SAndroid Build Coastguard Worker                    UserWarning,
8955*da0073e9SAndroid Build Coastguard Worker                    "This behavior is deprecated, and in a future PyTorch "
8956*da0073e9SAndroid Build Coastguard Worker                    "release outputs will not be resized unless they have "
8957*da0073e9SAndroid Build Coastguard Worker                    "zero elements"):
8958*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.normal(tensor2345, tensor2145, out=output345).size(), (2, 3, 4, 5))
8959*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
8960*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
8961*da0073e9SAndroid Build Coastguard Worker                    r"The size of tensor a \(5\) must match the size of "
8962*da0073e9SAndroid Build Coastguard Worker                    r"tensor b \(120\) at non-singleton dimension 3"):
8963*da0073e9SAndroid Build Coastguard Worker                # inputs are not expandable, output size is not the same as mean
8964*da0073e9SAndroid Build Coastguard Worker                torch.normal(tensor2345, tensor120, out=output345)
8965*da0073e9SAndroid Build Coastguard Worker
8966*da0073e9SAndroid Build Coastguard Worker    def test_tensoriterator_output_setup(self):
8967*da0073e9SAndroid Build Coastguard Worker        # Test whether the output's memory layout is correct
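        # Note: TensorIterator takes a "fast setup" path when all operands share a layout
        # (contiguous, channels-last, or dense tensors with identical strides); otherwise
        # it falls back to stride sorting to decide the output's layout.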
8968*da0073e9SAndroid Build Coastguard Worker        def test_memory_layout(x, y, scale, zero_point, out):
8969*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.dim(), 4)
8970*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.size(), y.size())
8971*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y.size(), out.size())
8972*da0073e9SAndroid Build Coastguard Worker
8973*da0073e9SAndroid Build Coastguard Worker            shape = x.size()
8974*da0073e9SAndroid Build Coastguard Worker            for n in range(shape[0]):
8975*da0073e9SAndroid Build Coastguard Worker                for c in range(shape[1]):
8976*da0073e9SAndroid Build Coastguard Worker                    for h in range(shape[2]):
8977*da0073e9SAndroid Build Coastguard Worker                        for w in range(shape[3]):
8978*da0073e9SAndroid Build Coastguard Worker                            if scale is not None and zero_point is not None:
8979*da0073e9SAndroid Build Coastguard Worker                                self.assertEqual(
8980*da0073e9SAndroid Build Coastguard Worker                                    out[n][c][h][w],
8981*da0073e9SAndroid Build Coastguard Worker                                    torch.ops.quantized.add(x[n][c][h][w], y[n][c][h][w], scale, zero_point))
8982*da0073e9SAndroid Build Coastguard Worker                            else:
8983*da0073e9SAndroid Build Coastguard Worker                                self.assertEqual(out[n][c][h][w], x[n][c][h][w] + y[n][c][h][w])
8984*da0073e9SAndroid Build Coastguard Worker
8985*da0073e9SAndroid Build Coastguard Worker        xraw = torch.rand(2, 3, 4, 4)
8986*da0073e9SAndroid Build Coastguard Worker        yraw = torch.rand(2, 3, 4, 4)
8987*da0073e9SAndroid Build Coastguard Worker        qxraw = torch.quantize_per_tensor(xraw, 0.1, 5, torch.quint8)
8988*da0073e9SAndroid Build Coastguard Worker        qyraw = torch.quantize_per_tensor(yraw, 0.1, 5, torch.quint8)
8989*da0073e9SAndroid Build Coastguard Worker
8990*da0073e9SAndroid Build Coastguard Worker        # contiguous case fast setup
8991*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(xraw, yraw, None, None, xraw + yraw)
8992*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(qxraw, qyraw, 0.1, 5, torch.ops.quantized.add(qxraw, qyraw, 0.1, 5))
8993*da0073e9SAndroid Build Coastguard Worker
8994*da0073e9SAndroid Build Coastguard Worker        # channels last case fast setup
8995*da0073e9SAndroid Build Coastguard Worker        x = xraw.contiguous(memory_format=torch.channels_last)
8996*da0073e9SAndroid Build Coastguard Worker        y = yraw.contiguous(memory_format=torch.channels_last)
8997*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(x, y, None, None, x + y)
8998*da0073e9SAndroid Build Coastguard Worker        qx = qxraw.contiguous(memory_format=torch.channels_last)
8999*da0073e9SAndroid Build Coastguard Worker        qy = qyraw.contiguous(memory_format=torch.channels_last)
9000*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
9001*da0073e9SAndroid Build Coastguard Worker
9002*da0073e9SAndroid Build Coastguard Worker        # non contiguous case fast setup (dense, non-overlapping, same shape and strides)
9003*da0073e9SAndroid Build Coastguard Worker        x = xraw.permute(0, 2, 3, 1)
9004*da0073e9SAndroid Build Coastguard Worker        y = yraw.permute(0, 2, 3, 1)
9005*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(x, y, None, None, x + y)
9006*da0073e9SAndroid Build Coastguard Worker        qx = qxraw.permute(0, 2, 3, 1)
9007*da0073e9SAndroid Build Coastguard Worker        qy = qyraw.permute(0, 2, 3, 1)
9008*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
9009*da0073e9SAndroid Build Coastguard Worker
9010*da0073e9SAndroid Build Coastguard Worker        # non contiguous case fast setup (dense, non-overlapping)
9011*da0073e9SAndroid Build Coastguard Worker        # input tensors have same shape and strides
9012*da0073e9SAndroid Build Coastguard Worker        # output tensor has the same shape as the input tensors but different strides
9013*da0073e9SAndroid Build Coastguard Worker        # output tensor should preserve its strides in this case
9014*da0073e9SAndroid Build Coastguard Worker        x = xraw.permute(0, 2, 3, 1)
9015*da0073e9SAndroid Build Coastguard Worker        y = yraw.permute(0, 2, 3, 1)
9016*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(xraw)
9017*da0073e9SAndroid Build Coastguard Worker        out = out.permute(0, 3, 2, 1)
9018*da0073e9SAndroid Build Coastguard Worker        expected_stride = out.stride()
9019*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(x, y, None, None, torch.add(x, y, out=out))
9020*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_stride, out.stride())
9021*da0073e9SAndroid Build Coastguard Worker
9022*da0073e9SAndroid Build Coastguard Worker        # non contiguous case non fast setup
9023*da0073e9SAndroid Build Coastguard Worker        x = xraw.permute(0, 2, 3, 1)
9024*da0073e9SAndroid Build Coastguard Worker        y = yraw.permute(0, 3, 2, 1)
9025*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(x, y, None, None, x + y)
9026*da0073e9SAndroid Build Coastguard Worker        qx = qxraw.permute(0, 2, 3, 1)
9027*da0073e9SAndroid Build Coastguard Worker        qy = qyraw.permute(0, 3, 2, 1)
9028*da0073e9SAndroid Build Coastguard Worker        test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
9029*da0073e9SAndroid Build Coastguard Worker
9030*da0073e9SAndroid Build Coastguard Worker    # Tests to make sure we still handle .data properly until it is removed
9031*da0073e9SAndroid Build Coastguard Worker    def test_dot_data_use(self):
9032*da0073e9SAndroid Build Coastguard Worker        # .data allows changing a Tensor's dtype in place; check that we still
9033*da0073e9SAndroid Build Coastguard Worker        # raise a clear error when the resulting dtypes are incompatible.
9034*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
9035*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
9036*da0073e9SAndroid Build Coastguard Worker                # message includes both Double and ComplexFloat
9037*da0073e9SAndroid Build Coastguard Worker                '(?=.*Double)(?=.*ComplexFloat)'):
9038*da0073e9SAndroid Build Coastguard Worker
9039*da0073e9SAndroid Build Coastguard Worker            # Calls the model with a DoubleTensor input but ComplexFloat weights
9040*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(1, 1, 1, 6, dtype=torch.double)
9041*da0073e9SAndroid Build Coastguard Worker            weight = torch.zeros(1, 1, 1, 3, dtype=torch.complex64)
9042*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False)
9043*da0073e9SAndroid Build Coastguard Worker            model.weight.data = weight
9044*da0073e9SAndroid Build Coastguard Worker            out = model(input)
9045*da0073e9SAndroid Build Coastguard Worker
9046*da0073e9SAndroid Build Coastguard Worker    def test_empty_storage_view(self):
9047*da0073e9SAndroid Build Coastguard Worker        # we should be able to "modify" slices of a 0-element
9048*da0073e9SAndroid Build Coastguard Worker        # array without an error being raised due to
9049*da0073e9SAndroid Build Coastguard Worker        # trying to resize its storage
9050*da0073e9SAndroid Build Coastguard Worker        t = torch.from_numpy(np.empty((0, 4)))
9051*da0073e9SAndroid Build Coastguard Worker        t[:, 1::2] *= 1
9052*da0073e9SAndroid Build Coastguard Worker
9053*da0073e9SAndroid Build Coastguard Worker    def test_has_storage(self):
9054*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(torch.tensor([]).storage())
9055*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(torch.empty(0).storage())
9056*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(torch.tensor([]).clone().storage())
9057*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(torch.tensor([0, 0, 0]).nonzero().storage())
9058*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(torch.tensor([]).new().storage())
9059*da0073e9SAndroid Build Coastguard Worker
9060*da0073e9SAndroid Build Coastguard Worker    # FIXME: Extend this test and put in a TensorProperties test class
9061*da0073e9SAndroid Build Coastguard Worker    def test_numel(self):
9062*da0073e9SAndroid Build Coastguard Worker        b = torch.ByteTensor(3, 100, 100)
9063*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.nelement(), 3 * 100 * 100)
9064*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.numel(), 3 * 100 * 100)
9065*da0073e9SAndroid Build Coastguard Worker
9066*da0073e9SAndroid Build Coastguard Worker    # Verifies that (deep)copies of dtypes are the same objects
9067*da0073e9SAndroid Build Coastguard Worker    def test_copy_dtypes(self):
9068*da0073e9SAndroid Build Coastguard Worker        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
9069*da0073e9SAndroid Build Coastguard Worker            copied_dtype = copy.deepcopy(dtype)
9070*da0073e9SAndroid Build Coastguard Worker            self.assertIs(dtype, copied_dtype)
9071*da0073e9SAndroid Build Coastguard Worker
9072*da0073e9SAndroid Build Coastguard Worker    def test_dtype_is_signed(self):
9073*da0073e9SAndroid Build Coastguard Worker        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
9074*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dtype.is_signed, torch.is_signed(torch.tensor(0, dtype=dtype)))
9075*da0073e9SAndroid Build Coastguard Worker
9076*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.quint8.is_signed)
9077*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.qint8.is_signed)
9078*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.qint32.is_signed)
9079*da0073e9SAndroid Build Coastguard Worker
9080*da0073e9SAndroid Build Coastguard Worker    # FIXME: Put the following random tests into their own test class or test suite
9081*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9082*da0073e9SAndroid Build Coastguard Worker    def test_RNGState(self):
9083*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
9084*da0073e9SAndroid Build Coastguard Worker        stateCloned = state.clone()
9085*da0073e9SAndroid Build Coastguard Worker        before = torch.rand(1000)
9086*da0073e9SAndroid Build Coastguard Worker
9087*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(state.ne(stateCloned).long().sum(), 0, atol=0, rtol=0)
9088*da0073e9SAndroid Build Coastguard Worker
9089*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
9090*da0073e9SAndroid Build Coastguard Worker        after = torch.rand(1000)
9091*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(before, after, atol=0, rtol=0)
9092*da0073e9SAndroid Build Coastguard Worker
9093*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9094*da0073e9SAndroid Build Coastguard Worker    def test_RNGStateAliasing(self):
9095*da0073e9SAndroid Build Coastguard Worker        # Fork the random number stream at this point
9096*da0073e9SAndroid Build Coastguard Worker        gen = torch.Generator()
9097*da0073e9SAndroid Build Coastguard Worker        gen.set_state(torch.get_rng_state())
9098*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gen.get_state(), torch.get_rng_state())
9099*da0073e9SAndroid Build Coastguard Worker
9100*da0073e9SAndroid Build Coastguard Worker        target_value = torch.rand(1000)
9101*da0073e9SAndroid Build Coastguard Worker        # Dramatically alter the internal state of the main generator
9102*da0073e9SAndroid Build Coastguard Worker        _ = torch.rand(100000)
9103*da0073e9SAndroid Build Coastguard Worker        forked_value = torch.rand(1000, generator=gen)
9104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(target_value, forked_value, atol=0, rtol=0, msg="RNG has not forked correctly.")
9105*da0073e9SAndroid Build Coastguard Worker
9106*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9107*da0073e9SAndroid Build Coastguard Worker    def test_RNG_after_pickle(self):
9108*da0073e9SAndroid Build Coastguard Worker        torch.random.manual_seed(100)
9109*da0073e9SAndroid Build Coastguard Worker        before = torch.rand(10)
9110*da0073e9SAndroid Build Coastguard Worker
9111*da0073e9SAndroid Build Coastguard Worker        torch.random.manual_seed(100)
9112*da0073e9SAndroid Build Coastguard Worker        buf = io.BytesIO()
9113*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor([1, 2, 3])
9114*da0073e9SAndroid Build Coastguard Worker        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor)
9115*da0073e9SAndroid Build Coastguard Worker        after = torch.rand(10)
9116*da0073e9SAndroid Build Coastguard Worker
9117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(before, after, atol=0, rtol=0)
9118*da0073e9SAndroid Build Coastguard Worker
9119*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9120*da0073e9SAndroid Build Coastguard Worker    def test_boxMullerState(self):
9121*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(123)
9122*da0073e9SAndroid Build Coastguard Worker        odd_number = 101
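        # Box-Muller style normal generation produces values in pairs; drawing an odd
        # number of samples leaves one cached value, so get/set_rng_state must also
        # capture that cached normal for the midstream draws to match.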
9123*da0073e9SAndroid Build Coastguard Worker        seeded = torch.randn(odd_number)
9124*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
9125*da0073e9SAndroid Build Coastguard Worker        midstream = torch.randn(odd_number)
9126*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
9127*da0073e9SAndroid Build Coastguard Worker        repeat_midstream = torch.randn(odd_number)
9128*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(123)
9129*da0073e9SAndroid Build Coastguard Worker        reseeded = torch.randn(odd_number)
9130*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(midstream, repeat_midstream, atol=0, rtol=0,
9131*da0073e9SAndroid Build Coastguard Worker                         msg='get_rng_state/set_rng_state not generating same sequence of normally distributed numbers')
9132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(seeded, reseeded, atol=0, rtol=0,
9133*da0073e9SAndroid Build Coastguard Worker                         msg='repeated calls to manual_seed not generating same sequence of normally distributed numbers')
9134*da0073e9SAndroid Build Coastguard Worker
9135*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098")
9136*da0073e9SAndroid Build Coastguard Worker    def test_manual_seed(self):
9137*da0073e9SAndroid Build Coastguard Worker        rng_state = torch.get_rng_state()
9138*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(2)
9139*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(100)
9140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.initial_seed(), 2)
9141*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(2)
9142*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(100)
9143*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y)
9144*da0073e9SAndroid Build Coastguard Worker
9145*da0073e9SAndroid Build Coastguard Worker        max_int64 = 0x7fff_ffff_ffff_ffff
9146*da0073e9SAndroid Build Coastguard Worker        min_int64 = -max_int64 - 1
9147*da0073e9SAndroid Build Coastguard Worker        max_uint64 = 0xffff_ffff_ffff_ffff
9148*da0073e9SAndroid Build Coastguard Worker        # Check all boundary cases of valid seed value inputs
9149*da0073e9SAndroid Build Coastguard Worker        test_cases = [
9150*da0073e9SAndroid Build Coastguard Worker            # (seed, expected_initial_seed)
9151*da0073e9SAndroid Build Coastguard Worker            # Positive seeds should be unchanged
9152*da0073e9SAndroid Build Coastguard Worker            (max_int64, max_int64),
9153*da0073e9SAndroid Build Coastguard Worker            (max_int64 + 1, max_int64 + 1),
9154*da0073e9SAndroid Build Coastguard Worker            (max_uint64, max_uint64),
9155*da0073e9SAndroid Build Coastguard Worker            (0, 0),
9156*da0073e9SAndroid Build Coastguard Worker            # Negative seeds wrap around starting from the largest seed value
9157*da0073e9SAndroid Build Coastguard Worker            (-1, max_uint64),
9158*da0073e9SAndroid Build Coastguard Worker            (min_int64, max_int64 + 1)
9159*da0073e9SAndroid Build Coastguard Worker        ]
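        # i.e. a negative seed s is interpreted as 2**64 + s, so -1 maps to
        # 0xffff_ffff_ffff_ffff and min_int64 maps to 2**63 (== max_int64 + 1)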
9160*da0073e9SAndroid Build Coastguard Worker        for seed, expected_initial_seed in test_cases:
9161*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(seed)
9162*da0073e9SAndroid Build Coastguard Worker            actual_initial_seed = torch.initial_seed()
9163*da0073e9SAndroid Build Coastguard Worker            msg = (f"expected initial_seed() = {expected_initial_seed:x} "
9164*da0073e9SAndroid Build Coastguard Worker                   f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead")
9165*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg)
9166*da0073e9SAndroid Build Coastguard Worker        for invalid_seed in [min_int64 - 1, max_uint64 + 1]:
9167*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long'):
9168*da0073e9SAndroid Build Coastguard Worker                torch.manual_seed(invalid_seed)
9169*da0073e9SAndroid Build Coastguard Worker
9170*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(rng_state)
9171*da0073e9SAndroid Build Coastguard Worker
9172*da0073e9SAndroid Build Coastguard Worker    # FIXME: Describe this test and port to the generic device framework in a more
9173*da0073e9SAndroid Build Coastguard Worker    #   appropriate test suite for the copy operation
9174*da0073e9SAndroid Build Coastguard Worker    def test_copy_transpose(self):
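        # .t() produces a transposed, non-contiguous view, so copy_ must handle a
        # strided (transposed) source rather than doing a flat contiguous copy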
9175*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(100 * 100, dtype=torch.float).reshape(100, 100).t()
9176*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(100, 100, dtype=torch.float)
9177*da0073e9SAndroid Build Coastguard Worker        y.copy_(x)
9178*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[:, 0], range(100))
9179*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[:, 40], range(4000, 4100))
9180*da0073e9SAndroid Build Coastguard Worker
9181*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(100, 100, dtype=torch.double)
9182*da0073e9SAndroid Build Coastguard Worker        y.copy_(x)
9183*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[:, 0], range(100))
9184*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[:, 40], range(4000, 4100))
9185*da0073e9SAndroid Build Coastguard Worker
9186*da0073e9SAndroid Build Coastguard Worker        # Validates regression reported in https://github.com/pytorch/pytorch/issues/45269
9187*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.cfloat).t()
9188*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(100, 100, dtype=torch.cfloat)
9189*da0073e9SAndroid Build Coastguard Worker        y.copy_(x)
9190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[:, 0], range(100))
9191*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[:, 40], range(4000, 4100))
9192*da0073e9SAndroid Build Coastguard Worker
9193*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.complex32).t()
9194*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(100, 100, dtype=torch.complex32)
9195*da0073e9SAndroid Build Coastguard Worker        y.copy_(x)
9196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[:, 0], range(100))
9197*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[:, 40], range(4000, 4100))
9198*da0073e9SAndroid Build Coastguard Worker
9199*da0073e9SAndroid Build Coastguard Worker    # FIXME: Port to a more appropriate test suite
9200*da0073e9SAndroid Build Coastguard Worker    def test_copy_broadcast(self):
9201*da0073e9SAndroid Build Coastguard Worker        torch.zeros(5, 6).copy_(torch.zeros(6))
9202*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30)))
9203*da0073e9SAndroid Build Coastguard Worker
9204*da0073e9SAndroid Build Coastguard Worker    # FIXME: Port to a more appropriate test suite
9205*da0073e9SAndroid Build Coastguard Worker    # Fails with inductor (and aot_eager) because functionalization replaces copy_ with copy,
9206*da0073e9SAndroid Build Coastguard Worker    # which doesn't properly error on bad inputs.
9207*da0073e9SAndroid Build Coastguard Worker    def test_copy_many_to_one(self):
9208*da0073e9SAndroid Build Coastguard Worker        # Test that an in-place copy which would write many source elements into the
9209*da0073e9SAndroid Build Coastguard Worker        # same (overlapping) destination storage locations raises a RuntimeError
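        # expand(5, 6) on a (1, 6) tensor gives stride 0 along dim 0, so all five
        # rows alias the same underlying storage elements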
9210*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6)))
9211*da0073e9SAndroid Build Coastguard Worker
9212*da0073e9SAndroid Build Coastguard Worker    def test_copy_float16(self):
9213*da0073e9SAndroid Build Coastguard Worker        # Check that fbgemm code no longer reads memory out of bounds, see
9214*da0073e9SAndroid Build Coastguard Worker        # copy_impl and fbgemm::Float16ToFloat_ref.
9215*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/88543
9216*da0073e9SAndroid Build Coastguard Worker
9217*da0073e9SAndroid Build Coastguard Worker        # Types to test different code paths in copy_impl.
9218*da0073e9SAndroid Build Coastguard Worker        dtypes = (
9219*da0073e9SAndroid Build Coastguard Worker            # out_dtype, src_dtype
9220*da0073e9SAndroid Build Coastguard Worker            (torch.float32, torch.float16),  # fbgemm
9221*da0073e9SAndroid Build Coastguard Worker            (torch.float16, torch.float32),  # fbgemm
9222*da0073e9SAndroid Build Coastguard Worker            (torch.float32, torch.float32),  # TensorIterator
9223*da0073e9SAndroid Build Coastguard Worker        )
9224*da0073e9SAndroid Build Coastguard Worker
9225*da0073e9SAndroid Build Coastguard Worker        cases = (
9226*da0073e9SAndroid Build Coastguard Worker            # out_shape, src_shape, is_ok
9227*da0073e9SAndroid Build Coastguard Worker            # These cases used to crash with fbgemm, make sure these also raise
9228*da0073e9SAndroid Build Coastguard Worker            # exceptions with TensorIterator.
9229*da0073e9SAndroid Build Coastguard Worker            ((1, 2, 3), (0, 2, 3), False),  # same strides, not allowed by TI
9230*da0073e9SAndroid Build Coastguard Worker            ((1, 5, 6), (4, 5, 6), False),  # same strides, not allowed by TI
9231*da0073e9SAndroid Build Coastguard Worker            (1, (0, 2, 3), False),  # different strides
9232*da0073e9SAndroid Build Coastguard Worker            ((4, 5, 6), (0, 2, 3), False),  # different strides
9233*da0073e9SAndroid Build Coastguard Worker            ((4, 5, 6), (1, 2, 3), False),  # different strides
9234*da0073e9SAndroid Build Coastguard Worker            ((4, 5, 6), (6, 5, 4), False),  # same numel
9235*da0073e9SAndroid Build Coastguard Worker
9236*da0073e9SAndroid Build Coastguard Worker            # These cases should pass with fbgemm and TensorIterator.
9237*da0073e9SAndroid Build Coastguard Worker            ((4, 5, 6), (1, 5, 6), True),  # same strides
9238*da0073e9SAndroid Build Coastguard Worker            ((4, 5, 6), (4, 5, 6), True),  # same strides
9239*da0073e9SAndroid Build Coastguard Worker            ((0, 2, 3), 1, True),  # different strides, allowed by TI
9240*da0073e9SAndroid Build Coastguard Worker            ((4, 5, 6), (4, 5, 1), True),  # different strides, allowed by TI
9241*da0073e9SAndroid Build Coastguard Worker        )
9242*da0073e9SAndroid Build Coastguard Worker
9243*da0073e9SAndroid Build Coastguard Worker        for (out_shape, src_shape, is_ok), (out_dtype, src_dtype) in itertools.product(cases, dtypes):
9244*da0073e9SAndroid Build Coastguard Worker            out = torch.zeros(out_shape, dtype=out_dtype, device=torch.device('cpu'))
9245*da0073e9SAndroid Build Coastguard Worker            src = torch.ones(src_shape, dtype=src_dtype, device=torch.device('cpu'))
9246*da0073e9SAndroid Build Coastguard Worker            if is_ok:
9247*da0073e9SAndroid Build Coastguard Worker                if torch.cuda.is_available():
9248*da0073e9SAndroid Build Coastguard Worker                    out_cuda = out.cuda()
9249*da0073e9SAndroid Build Coastguard Worker                    src_cuda = src.cuda()
9250*da0073e9SAndroid Build Coastguard Worker                res = out.copy_(src)
9251*da0073e9SAndroid Build Coastguard Worker                if torch.cuda.is_available():
9252*da0073e9SAndroid Build Coastguard Worker                    res_cuda = out_cuda.copy_(src_cuda)
9253*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res, res_cuda)
9254*da0073e9SAndroid Build Coastguard Worker            else:
9255*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: out.copy_(src))
9256*da0073e9SAndroid Build Coastguard Worker
9257*da0073e9SAndroid Build Coastguard Worker    # FIXME: Port to a more appropriate test suite
9258*da0073e9SAndroid Build Coastguard Worker    def _test_to_with_layout(self, layout):
9259*da0073e9SAndroid Build Coastguard Worker        def test_copy_behavior(t, non_blocking=False):
9260*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, t.to(t, non_blocking=non_blocking))
9261*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
9262*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
9263*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
9264*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
9265*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True))
9266*da0073e9SAndroid Build Coastguard Worker
9267*da0073e9SAndroid Build Coastguard Worker            devices = [t.device]
9268*da0073e9SAndroid Build Coastguard Worker            if t.device.type == 'cuda':
9269*da0073e9SAndroid Build Coastguard Worker                if t.device.index == -1:
9270*da0073e9SAndroid Build Coastguard Worker                    devices.append(f'cuda:{torch.cuda.current_device()}')
9271*da0073e9SAndroid Build Coastguard Worker                elif t.device.index == torch.cuda.current_device():
9272*da0073e9SAndroid Build Coastguard Worker                    devices.append('cuda')
9273*da0073e9SAndroid Build Coastguard Worker            for device in devices:
9274*da0073e9SAndroid Build Coastguard Worker                self.assertIs(t, t.to(device, non_blocking=non_blocking))
9275*da0073e9SAndroid Build Coastguard Worker                self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
9276*da0073e9SAndroid Build Coastguard Worker                self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
9277*da0073e9SAndroid Build Coastguard Worker                self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True))
9278*da0073e9SAndroid Build Coastguard Worker
9279*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(5)
9280*da0073e9SAndroid Build Coastguard Worker        if layout == torch.sparse_csr:
9281*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([[0, 1, 2], [2, 0, 3]]).to_sparse_csr()
9282*da0073e9SAndroid Build Coastguard Worker        test_copy_behavior(a)
9283*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.device, a.to('cpu').device)
9284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device)
9285*da0073e9SAndroid Build Coastguard Worker        self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype)
9286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.device, a.to(torch.float32).device)
9287*da0073e9SAndroid Build Coastguard Worker        self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
9288*da0073e9SAndroid Build Coastguard Worker
9289*da0073e9SAndroid Build Coastguard Worker        def test_data_ptr(getter):
9290*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(getter(a), getter(a.to('cpu')))
9291*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(getter(a), getter(a.to(dtype=a.dtype, device=a.device, copy=False)))
9292*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(getter(a), getter(a.to('cpu', copy=False)))
9293*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(getter(a), getter(a.to('cpu', copy=True)))
9294*da0073e9SAndroid Build Coastguard Worker        if layout == torch.sparse_csr:
9295*da0073e9SAndroid Build Coastguard Worker            # TODO: compressed sparse tensors currently don't support data_ptr.
9296*da0073e9SAndroid Build Coastguard Worker            # Exercising the failure will let us widen coverage of this test once they do.
9297*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer of Tensor that doesn't have storage"):
9298*da0073e9SAndroid Build Coastguard Worker                a.data_ptr()
9299*da0073e9SAndroid Build Coastguard Worker            # While compressed sparse tensors don't have a concept of data_ptr,
9300*da0073e9SAndroid Build Coastguard Worker            # their underlying component tensors do. The implementation of to() appropriately
9301*da0073e9SAndroid Build Coastguard Worker            # forwards the call to the components, which is what we're testing here.
9302*da0073e9SAndroid Build Coastguard Worker            test_data_ptr(lambda a: a.values().data_ptr())
9303*da0073e9SAndroid Build Coastguard Worker            test_data_ptr(lambda a: a.crow_indices().data_ptr())
9304*da0073e9SAndroid Build Coastguard Worker            test_data_ptr(lambda a: a.col_indices().data_ptr())
9305*da0073e9SAndroid Build Coastguard Worker        else:
9306*da0073e9SAndroid Build Coastguard Worker            test_data_ptr(lambda a: a.data_ptr())
9307*da0073e9SAndroid Build Coastguard Worker
9308*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
9309*da0073e9SAndroid Build Coastguard Worker            for non_blocking in [True, False]:
9310*da0073e9SAndroid Build Coastguard Worker                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
9311*da0073e9SAndroid Build Coastguard Worker                    b = torch.tensor(5., device=cuda)
9312*da0073e9SAndroid Build Coastguard Worker                    test_copy_behavior(b, non_blocking)
9313*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device)
9314*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device)
9315*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device)
9316*da0073e9SAndroid Build Coastguard Worker                    self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype)
9317*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device)
9318*da0073e9SAndroid Build Coastguard Worker                    self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
9319*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.device, b.to(dtype=torch.int32).device)
9320*da0073e9SAndroid Build Coastguard Worker
9321*da0073e9SAndroid Build Coastguard Worker    def test_to(self):
9322*da0073e9SAndroid Build Coastguard Worker        self._test_to_with_layout(torch.strided)
9323*da0073e9SAndroid Build Coastguard Worker        is_cuda10_2_or_higher = (
9324*da0073e9SAndroid Build Coastguard Worker            (torch.version.cuda is not None)
9325*da0073e9SAndroid Build Coastguard Worker            and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
9326*da0073e9SAndroid Build Coastguard Worker        if is_cuda10_2_or_higher:  # in cuda10_1 sparse_csr is beta
9327*da0073e9SAndroid Build Coastguard Worker            self._test_to_with_layout(torch.sparse_csr)
9328*da0073e9SAndroid Build Coastguard Worker
9329*da0073e9SAndroid Build Coastguard Worker    # FIXME: describe this test
9330*da0073e9SAndroid Build Coastguard Worker    def test_as_subclass(self):
9331*da0073e9SAndroid Build Coastguard Worker        class SubTensor(torch.Tensor):
9332*da0073e9SAndroid Build Coastguard Worker            member_var = object()
9333*da0073e9SAndroid Build Coastguard Worker
9334*da0073e9SAndroid Build Coastguard Worker        t0 = torch.tensor(0)
9335*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([1, 2])
9336*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([[3, 4], [5, 6]])
9337*da0073e9SAndroid Build Coastguard Worker
9338*da0073e9SAndroid Build Coastguard Worker        s0 = t0.as_subclass(SubTensor)
9339*da0073e9SAndroid Build Coastguard Worker        s1 = t1.as_subclass(SubTensor)
9340*da0073e9SAndroid Build Coastguard Worker        s2 = t2.as_subclass(SubTensor)
9341*da0073e9SAndroid Build Coastguard Worker
9342*da0073e9SAndroid Build Coastguard Worker        # Check that the correct type is returned.
9343*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(type(s0) is SubTensor)
9344*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(type(s1) is SubTensor)
9345*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(type(s2) is SubTensor)
9346*da0073e9SAndroid Build Coastguard Worker
9347*da0073e9SAndroid Build Coastguard Worker        # Check that the data is equal.
9348*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t0, s0)
9349*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1, s1)
9350*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t2, s2)
9351*da0073e9SAndroid Build Coastguard Worker
9352*da0073e9SAndroid Build Coastguard Worker        t0[()] = 1
9353*da0073e9SAndroid Build Coastguard Worker        t1[1] = 3
9354*da0073e9SAndroid Build Coastguard Worker        t2[1, 1] = 7
9355*da0073e9SAndroid Build Coastguard Worker
9356*da0073e9SAndroid Build Coastguard Worker        # Check that the data is equal even after modification.
9357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t0, s0)
9358*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1, s1)
9359*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t2, s2)
9360*da0073e9SAndroid Build Coastguard Worker
9361*da0073e9SAndroid Build Coastguard Worker        # Check that member variables are passed through.
9362*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0.member_var is SubTensor.member_var)
9363*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s1.member_var is SubTensor.member_var)
9364*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s2.member_var is SubTensor.member_var)
9365*da0073e9SAndroid Build Coastguard Worker
9366*da0073e9SAndroid Build Coastguard Worker        # Test that autograd is propagated.
9367*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(5, dtype=torch.float32, requires_grad=True)
9368*da0073e9SAndroid Build Coastguard Worker
9369*da0073e9SAndroid Build Coastguard Worker        # Run a calculation on the tensor.
9370*da0073e9SAndroid Build Coastguard Worker        exp_t = torch.exp(t)
9371*da0073e9SAndroid Build Coastguard Worker
9372*da0073e9SAndroid Build Coastguard Worker        # Cast exp_t to a subclass.
9373*da0073e9SAndroid Build Coastguard Worker        exp_s = exp_t.as_subclass(SubTensor)
9374*da0073e9SAndroid Build Coastguard Worker
9375*da0073e9SAndroid Build Coastguard Worker        # Make sure that t.grad was initially None
9376*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t.grad is None)
9377*da0073e9SAndroid Build Coastguard Worker
9378*da0073e9SAndroid Build Coastguard Worker        # Run the autograd calculation.
9379*da0073e9SAndroid Build Coastguard Worker        exp_s.backward()
9380*da0073e9SAndroid Build Coastguard Worker
9381*da0073e9SAndroid Build Coastguard Worker        # Make sure autograd was propagated to the original tensor
9382*da0073e9SAndroid Build Coastguard Worker        # declared with requires_grad.
9383*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t.grad is not None)
9384*da0073e9SAndroid Build Coastguard Worker
9385*da0073e9SAndroid Build Coastguard Worker        # Make sure invalid subclasses raise nice errors
9386*da0073e9SAndroid Build Coastguard Worker        class BadSubTensor:
9387*da0073e9SAndroid Build Coastguard Worker            member_var = object()
9388*da0073e9SAndroid Build Coastguard Worker
9389*da0073e9SAndroid Build Coastguard Worker        err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
9390*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
9391*da0073e9SAndroid Build Coastguard Worker            s0 = t0.as_subclass(BadSubTensor)
9392*da0073e9SAndroid Build Coastguard Worker
9393*da0073e9SAndroid Build Coastguard Worker    # FIXME: Port to a test suite that better fits slicing
9394*da0073e9SAndroid Build Coastguard Worker    def test_slice(self):
9395*da0073e9SAndroid Build Coastguard Worker        empty = torch.empty(0, 4)
9396*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0., 16).view(4, 4)
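        # x == [[ 0.,  1.,  2.,  3.],
        #       [ 4.,  5.,  6.,  7.],
        #       [ 8.,  9., 10., 11.],
        #       [12., 13., 14., 15.]]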
9397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[:], x)
9398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[:4], x)
9399*da0073e9SAndroid Build Coastguard Worker        # start and stop are clamped to the size of dim
9400*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[:5], x)
9401*da0073e9SAndroid Build Coastguard Worker        # if start >= stop then the result is empty
9402*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[2:1], empty)
9403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[2:2], empty)
9404*da0073e9SAndroid Build Coastguard Worker        # out of bounds is also empty
9405*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[10:12], empty)
9406*da0073e9SAndroid Build Coastguard Worker        # additional correctness checks
9407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[:1].tolist(), [[0, 1, 2, 3]])
9408*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[:-3].tolist(), [[0, 1, 2, 3]])
9409*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[:, -2:3].tolist(), [[2], [6], [10], [14]])
9410*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[0:-1:2].tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]])
9411*da0073e9SAndroid Build Coastguard Worker
9412*da0073e9SAndroid Build Coastguard Worker    def test_split_with_sizes_copy_out(self):
9413*da0073e9SAndroid Build Coastguard Worker        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
9414*da0073e9SAndroid Build Coastguard Worker        shape = (30, 40, 50)
9415*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(*shape, device=device)
9416*da0073e9SAndroid Build Coastguard Worker        cases = [
9417*da0073e9SAndroid Build Coastguard Worker            (0, [3, 7, 8, 12]),
9418*da0073e9SAndroid Build Coastguard Worker            (1, [3, 7, 10, 20]),
9419*da0073e9SAndroid Build Coastguard Worker            (-2, [3, 7, 10, 20]),
9420*da0073e9SAndroid Build Coastguard Worker            (2, [3, 7, 10, 12, 18]),
9421*da0073e9SAndroid Build Coastguard Worker            (-1, [3, 7, 10, 12, 18]),
9422*da0073e9SAndroid Build Coastguard Worker            (2, [3, 7, 10, 0, 30]),
9423*da0073e9SAndroid Build Coastguard Worker        ]
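        # Each split-size list sums to the size of the dimension being split
        # (30, 40, or 50); zero-sized splits are allowed.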
9424*da0073e9SAndroid Build Coastguard Worker        for dim, split_sizes in cases:
9425*da0073e9SAndroid Build Coastguard Worker            views = x.split_with_sizes(split_sizes, dim=dim)
9426*da0073e9SAndroid Build Coastguard Worker            expects = [v.clone() for v in views]
9427*da0073e9SAndroid Build Coastguard Worker            out = [torch.zeros_like(v) for v in views]
9428*da0073e9SAndroid Build Coastguard Worker            for expect, t in zip(expects, out):
9429*da0073e9SAndroid Build Coastguard Worker                if expect.numel() != 0:
9430*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(expect.eq(t).all().item())
9431*da0073e9SAndroid Build Coastguard Worker
9432*da0073e9SAndroid Build Coastguard Worker            torch.split_with_sizes_copy(x, split_sizes, dim=dim, out=out)
9433*da0073e9SAndroid Build Coastguard Worker            for expect, t in zip(expects, out):
9434*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(expect.eq(t).all().item())
9435*da0073e9SAndroid Build Coastguard Worker
9436*da0073e9SAndroid Build Coastguard Worker            if not torch.cuda.is_available():
9437*da0073e9SAndroid Build Coastguard Worker                continue
9438*da0073e9SAndroid Build Coastguard Worker
9439*da0073e9SAndroid Build Coastguard Worker            # Test with CUDA graph
9440*da0073e9SAndroid Build Coastguard Worker            out = [torch.zeros_like(v) for v in views]
9441*da0073e9SAndroid Build Coastguard Worker            for expect, t in zip(expects, out):
9442*da0073e9SAndroid Build Coastguard Worker                if expect.numel() != 0:
9443*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(expect.eq(t).all().item())
9444*da0073e9SAndroid Build Coastguard Worker
9445*da0073e9SAndroid Build Coastguard Worker            g = torch.cuda.CUDAGraph()
9446*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.graph(g):
9447*da0073e9SAndroid Build Coastguard Worker                torch.split_with_sizes_copy(x, split_sizes, dim=dim, out=out)
9448*da0073e9SAndroid Build Coastguard Worker
9449*da0073e9SAndroid Build Coastguard Worker            g.replay()
9450*da0073e9SAndroid Build Coastguard Worker            for expect, t in zip(expects, out):
9451*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(expect.eq(t).all().item())
9452*da0073e9SAndroid Build Coastguard Worker
9453*da0073e9SAndroid Build Coastguard Worker    def test_type(self):
9454*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3).double()
9455*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32)
9456*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.type(torch.FloatTensor).dtype, torch.float32)
9457*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.int().type(torch.Tensor).dtype, torch.get_default_dtype())
9458*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.type(torch.int32).dtype, torch.int32)
9459*da0073e9SAndroid Build Coastguard Worker
9460*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to a quantization test suite
9461*da0073e9SAndroid Build Coastguard Worker    def test_qengine(self):
9462*da0073e9SAndroid Build Coastguard Worker        qengines = torch.backends.quantized.supported_engines
9463*da0073e9SAndroid Build Coastguard Worker        original_qe = torch.backends.quantized.engine
9464*da0073e9SAndroid Build Coastguard Worker        for qe in qengines:
9465*da0073e9SAndroid Build Coastguard Worker            torch.backends.quantized.engine = qe
9466*da0073e9SAndroid Build Coastguard Worker            assert torch.backends.quantized.engine == qe, 'qengine not set successfully'
9467*da0073e9SAndroid Build Coastguard Worker        torch.backends.quantized.engine = original_qe
9468*da0073e9SAndroid Build Coastguard Worker
9469*da0073e9SAndroid Build Coastguard Worker    def test_terminate_handler_on_crash(self):
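        # Setting TORCH_CUSTOM_TERMINATE is expected to install c10's custom terminate
        # handler (see c10/util/AbortHandler.h), which should report the abort below.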
9470*da0073e9SAndroid Build Coastguard Worker        cmd = [sys.executable, '-c', "import os; os.environ[\"TORCH_CUSTOM_TERMINATE\"] ='1'; \
9471*da0073e9SAndroid Build Coastguard Worker               import torch; import torch._C; torch._C._abort()"]
9472*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(subprocess.CalledProcessError) as cm:
9473*da0073e9SAndroid Build Coastguard Worker            subprocess.check_output(cmd, shell=False)
9474*da0073e9SAndroid Build Coastguard Worker        e = cm.exception
9475*da0073e9SAndroid Build Coastguard Worker        output = e.stdout.decode("utf-8")
9476*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(e.returncode, 0)
9477*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(output, None)
9478*da0073e9SAndroid Build Coastguard Worker        self.assertIn('Unhandled exception caught in c10/util/AbortHandler.h', output)
9479*da0073e9SAndroid Build Coastguard Worker
9480*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to a distributed test suite -- also... how could this be OOMing on Windows CUDA?
9481*da0073e9SAndroid Build Coastguard Worker    @slowTest
9482*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
9483*da0073e9SAndroid Build Coastguard Worker                        don't support multiprocessing with spawn start method")
9484*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows')
9485*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_invalid_probs(self):
9486*da0073e9SAndroid Build Coastguard Worker        def _spawn_method(method, arg):
9487*da0073e9SAndroid Build Coastguard Worker            try:
9488*da0073e9SAndroid Build Coastguard Worker                mp.set_start_method('spawn')
9489*da0073e9SAndroid Build Coastguard Worker            except RuntimeError:
9490*da0073e9SAndroid Build Coastguard Worker                pass
9491*da0073e9SAndroid Build Coastguard Worker            with mp.Pool(1) as pool:
9492*da0073e9SAndroid Build Coastguard Worker                out = pool.map(method, [arg])
9493*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(out[0])
9494*da0073e9SAndroid Build Coastguard Worker
9495*da0073e9SAndroid Build Coastguard Worker        def _test_multinomial_invalid_probs(probs):
9496*da0073e9SAndroid Build Coastguard Worker            try:
9497*da0073e9SAndroid Build Coastguard Worker                # n_sample = 1 is a special case, test n_sample=2 which is more general
9498*da0073e9SAndroid Build Coastguard Worker                torch.multinomial(probs.to('cpu'), 2)
9499*da0073e9SAndroid Build Coastguard Worker                return False  # Should not be reached
9500*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as e:
9501*da0073e9SAndroid Build Coastguard Worker                return 'probability tensor contains either `inf`, `nan` or element < 0' in str(e)
9502*da0073e9SAndroid Build Coastguard Worker
9503*da0073e9SAndroid Build Coastguard Worker        _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., -1., 1.]))
9504*da0073e9SAndroid Build Coastguard Worker        _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., inf, 1.]))
9505*da0073e9SAndroid Build Coastguard Worker        _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., -inf, 1.]))
9506*da0073e9SAndroid Build Coastguard Worker        _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., 1., nan]))
9507*da0073e9SAndroid Build Coastguard Worker
9508*da0073e9SAndroid Build Coastguard Worker    # FIXME: port to more appropriate test suite
9509*da0073e9SAndroid Build Coastguard Worker    def test_to_with_tensor(self):
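        # Tensor.to(other) adopts `other`'s device (and dtype); this test only checks
        # that the resulting device matches, including the CPU/CUDA combinations when
        # CUDA is available.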
9510*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(5)
9511*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.device, a.to(a).device)
9512*da0073e9SAndroid Build Coastguard Worker
9513*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
9514*da0073e9SAndroid Build Coastguard Worker            for non_blocking in [True, False]:
9515*da0073e9SAndroid Build Coastguard Worker                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
9516*da0073e9SAndroid Build Coastguard Worker                    b = torch.tensor(5., device=cuda)
9517*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device)
9518*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device)
9519*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device)
9520*da0073e9SAndroid Build Coastguard Worker
9521*da0073e9SAndroid Build Coastguard Worker    def test_device(self):
9522*da0073e9SAndroid Build Coastguard Worker        cpu = torch.device('cpu')
9523*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cpu', str(cpu))
9524*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cpu', cpu.type)
9525*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(None, cpu.index)
9526*da0073e9SAndroid Build Coastguard Worker
9527*da0073e9SAndroid Build Coastguard Worker        cpu0 = torch.device('cpu:0')
9528*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cpu:0', str(cpu0))
9529*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cpu', cpu0.type)
9530*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, cpu0.index)
9531*da0073e9SAndroid Build Coastguard Worker
9532*da0073e9SAndroid Build Coastguard Worker        cpu0 = torch.device('cpu', 0)
9533*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cpu:0', str(cpu0))
9534*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cpu', cpu0.type)
9535*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, cpu0.index)
9536*da0073e9SAndroid Build Coastguard Worker
9537*da0073e9SAndroid Build Coastguard Worker        cuda = torch.device('cuda')
9538*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cuda', str(cuda))
9539*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cuda', cuda.type)
9540*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(None, cuda.index)
9541*da0073e9SAndroid Build Coastguard Worker
9542*da0073e9SAndroid Build Coastguard Worker        cuda1 = torch.device('cuda:1')
9543*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cuda:1', str(cuda1))
9544*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cuda', cuda1.type)
9545*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, cuda1.index)
9546*da0073e9SAndroid Build Coastguard Worker
9547*da0073e9SAndroid Build Coastguard Worker        cuda1 = torch.device('cuda', 1)
9548*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cuda:1', str(cuda1))
9549*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cuda', cuda1.type)
9550*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, cuda1.index)
9551*da0073e9SAndroid Build Coastguard Worker
9552*da0073e9SAndroid Build Coastguard Worker        cuda90 = torch.device('cuda', 90)
9553*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cuda:90', str(cuda90))
9554*da0073e9SAndroid Build Coastguard Worker        self.assertEqual('cuda', cuda90.type)
9555*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(90, cuda90.index)
9556*da0073e9SAndroid Build Coastguard Worker
9557*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1'))
9558*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1'))
9559*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 '))
9560*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2'))
9561*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 2'))
9562*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.'))
9563*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2?'))
9564*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:?2'))
9565*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:'))
9566*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.232'))
9567*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3'))
9568*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3'))
9569*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3'))
9570*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device(-1))
9571*da0073e9SAndroid Build Coastguard Worker
9572*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('other'))
9573*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch.device('other:0'))
9574*da0073e9SAndroid Build Coastguard Worker
9575*da0073e9SAndroid Build Coastguard Worker        device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
9576*da0073e9SAndroid Build Coastguard Worker        device_hash_set = set()
9577*da0073e9SAndroid Build Coastguard Worker        device_hash_set.update(hash(torch.device(device)) for device in device_set)
9578*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(device_set), len(device_hash_set))
9579*da0073e9SAndroid Build Coastguard Worker
9580*da0073e9SAndroid Build Coastguard Worker        def get_expected_device_repr(device):
9581*da0073e9SAndroid Build Coastguard Worker            if device.index is not None:
9582*da0073e9SAndroid Build Coastguard Worker                return f"device(type='{device.type}', index={device.index})"
9583*da0073e9SAndroid Build Coastguard Worker
9584*da0073e9SAndroid Build Coastguard Worker            return f"device(type='{device.type}')"
9585*da0073e9SAndroid Build Coastguard Worker
9586*da0073e9SAndroid Build Coastguard Worker        for device in device_set:
9587*da0073e9SAndroid Build Coastguard Worker            dev = torch.device(device)
9588*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(repr(dev), get_expected_device_repr(dev))
9589*da0073e9SAndroid Build Coastguard Worker
9590*da0073e9SAndroid Build Coastguard Worker    # Tests that the deterministic algorithms flag can be set as expected via torch.use_deterministic_algorithms
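    # A minimal usage sketch, mirroring the assertions below: the string debug modes
    # map onto the (deterministic, warn_only) flags as
    #   torch.set_deterministic_debug_mode('default')  # 0: deterministic=False
    #   torch.set_deterministic_debug_mode('warn')     # 1: deterministic=True, warn_only=True
    #   torch.set_deterministic_debug_mode('error')    # 2: deterministic=True, warn_only=False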
9591*da0073e9SAndroid Build Coastguard Worker    @wrapDeterministicFlagAPITest
9592*da0073e9SAndroid Build Coastguard Worker    def test_deterministic_flag(self):
9593*da0073e9SAndroid Build Coastguard Worker        for deterministic, warn_only in product([True, False], [True, False]):
9594*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(deterministic, warn_only=warn_only)
9595*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled())
9596*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled())
9597*da0073e9SAndroid Build Coastguard Worker
9598*da0073e9SAndroid Build Coastguard Worker            if deterministic:
9599*da0073e9SAndroid Build Coastguard Worker                if warn_only:
9600*da0073e9SAndroid Build Coastguard Worker                    debug_mode = 1
9601*da0073e9SAndroid Build Coastguard Worker                else:
9602*da0073e9SAndroid Build Coastguard Worker                    debug_mode = 2
9603*da0073e9SAndroid Build Coastguard Worker            else:
9604*da0073e9SAndroid Build Coastguard Worker                debug_mode = 0
9605*da0073e9SAndroid Build Coastguard Worker
9606*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
9607*da0073e9SAndroid Build Coastguard Worker
9608*da0073e9SAndroid Build Coastguard Worker        for debug_mode in [0, 1, 2]:
9609*da0073e9SAndroid Build Coastguard Worker            torch.set_deterministic_debug_mode(debug_mode)
9610*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
9611*da0073e9SAndroid Build Coastguard Worker            deterministic = debug_mode in [1, 2]
9612*da0073e9SAndroid Build Coastguard Worker            warn_only = debug_mode == 1
9613*da0073e9SAndroid Build Coastguard Worker
9614*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled())
9615*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled())
9616*da0073e9SAndroid Build Coastguard Worker
9617*da0073e9SAndroid Build Coastguard Worker        for debug_mode, debug_mode_str in [(0, 'default'), (1, 'warn'), (2, 'error')]:
9618*da0073e9SAndroid Build Coastguard Worker            torch.set_deterministic_debug_mode(debug_mode_str)
9619*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
9620*da0073e9SAndroid Build Coastguard Worker
9621*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
9622*da0073e9SAndroid Build Coastguard Worker                TypeError,
9623*da0073e9SAndroid Build Coastguard Worker                r"_set_deterministic_algorithms\(\): argument 'mode' \(position 1\) must be bool, not int"):
9624*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(1)
9625*da0073e9SAndroid Build Coastguard Worker
9626*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
9627*da0073e9SAndroid Build Coastguard Worker                TypeError,
9628*da0073e9SAndroid Build Coastguard Worker                r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"):
9629*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(False, warn_only=1)
9630*da0073e9SAndroid Build Coastguard Worker
9631*da0073e9SAndroid Build Coastguard Worker    # Tests that torch.utils.deterministic.fill_uninitialized_memory can be set as expected
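    # Usage sketch (illustrative only): with deterministic algorithms enabled, setting
    #   torch.utils.deterministic.fill_uninitialized_memory = True
    # is documented to make uninitialized memory (e.g. from torch.empty) be filled with
    # a known value, trading some performance for reproducibility.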
9632*da0073e9SAndroid Build Coastguard Worker    def test_deterministic_fill_uninitialized_memory(self):
9633*da0073e9SAndroid Build Coastguard Worker        with DeterministicGuard(True, fill_uninitialized_memory=False):
9634*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
9635*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
9636*da0073e9SAndroid Build Coastguard Worker
9637*da0073e9SAndroid Build Coastguard Worker            with DeterministicGuard(True, fill_uninitialized_memory=True):
9638*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
9639*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
9640*da0073e9SAndroid Build Coastguard Worker
9641*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
9642*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
9643*da0073e9SAndroid Build Coastguard Worker
9644*da0073e9SAndroid Build Coastguard Worker            torch.utils.deterministic.fill_uninitialized_memory = False
9645*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
9646*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
9647*da0073e9SAndroid Build Coastguard Worker
9648*da0073e9SAndroid Build Coastguard Worker            torch.utils.deterministic.fill_uninitialized_memory = True
9649*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
9650*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
9651*da0073e9SAndroid Build Coastguard Worker
9652*da0073e9SAndroid Build Coastguard Worker            torch._C._set_deterministic_fill_uninitialized_memory(False)
9653*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
9654*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
9655*da0073e9SAndroid Build Coastguard Worker
9656*da0073e9SAndroid Build Coastguard Worker            torch._C._set_deterministic_fill_uninitialized_memory(True)
9657*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
9658*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
9659*da0073e9SAndroid Build Coastguard Worker
9660*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"expected a bool, but got int"):
9661*da0073e9SAndroid Build Coastguard Worker                torch.utils.deterministic.fill_uninitialized_memory = 1
9662*da0073e9SAndroid Build Coastguard Worker
9663*da0073e9SAndroid Build Coastguard Worker    def test_type_conversion_via_dtype_name(self):
9664*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1])
9665*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.byte().dtype, torch.uint8)
9666*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.bool().dtype, torch.bool)
9667*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.char().dtype, torch.int8)
9668*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.double().dtype, torch.float64)
9669*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.float().dtype, torch.float32)
9670*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.half().dtype, torch.float16)
9671*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.int().dtype, torch.int32)
9672*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.bfloat16().dtype, torch.bfloat16)
9673*da0073e9SAndroid Build Coastguard Worker        cfloat = x.cfloat()
9674*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cfloat.dtype, torch.complex64)
9675*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cfloat.real, x.float())
9676*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cfloat.imag, torch.zeros_like(cfloat.imag))
9677*da0073e9SAndroid Build Coastguard Worker        cdouble = x.cdouble()
9678*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cdouble.dtype, torch.complex128)
9679*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cdouble.real, x.double())
9680*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cdouble.imag, torch.zeros_like(cdouble.imag))
9681*da0073e9SAndroid Build Coastguard Worker        chalf = x.chalf()
9682*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chalf.dtype, torch.complex32)
9683*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chalf.real, x.half())
9684*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chalf.imag, torch.zeros_like(chalf.imag))
9685*da0073e9SAndroid Build Coastguard Worker
9686*da0073e9SAndroid Build Coastguard Worker    def test_type_alias(self):
9687*da0073e9SAndroid Build Coastguard Worker        type_alias_map = {torch.float64: torch.double,
9688*da0073e9SAndroid Build Coastguard Worker                          torch.float32: torch.float,
9689*da0073e9SAndroid Build Coastguard Worker                          torch.int32: torch.int,
9690*da0073e9SAndroid Build Coastguard Worker                          torch.int64: torch.long,
9691*da0073e9SAndroid Build Coastguard Worker                          torch.int16: torch.short,
9692*da0073e9SAndroid Build Coastguard Worker                          torch.float16: torch.half,
9693*da0073e9SAndroid Build Coastguard Worker                          torch.complex32: torch.chalf,
9694*da0073e9SAndroid Build Coastguard Worker                          torch.complex64: torch.cfloat}
9695*da0073e9SAndroid Build Coastguard Worker        for dtype, alias in type_alias_map.items():
9696*da0073e9SAndroid Build Coastguard Worker            self.assertIs(alias, dtype)
9697*da0073e9SAndroid Build Coastguard Worker
9698*da0073e9SAndroid Build Coastguard Worker    def test_doc_template(self) -> None:
9699*da0073e9SAndroid Build Coastguard Worker        """
9700*da0073e9SAndroid Build Coastguard Worker        Test that all public API doc strings use the same standard template for
9701*da0073e9SAndroid Build Coastguard Worker        all common arguments such as tensor or dim
9702*da0073e9SAndroid Build Coastguard Worker        """
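        # Rough sketch of the assumed mechanism: docstrings in torch._torch_docs are
        # written with placeholders such as "{dim}" and are formatted against shared
        # argument descriptions (e.g. single_dim_common), so a docstring that inlines
        # one of those descriptions verbatim is flagged here as a missed template use.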
9703*da0073e9SAndroid Build Coastguard Worker        from torch._torch_docs import __file__ as doc_file
9704*da0073e9SAndroid Build Coastguard Worker        from torch._torch_docs import multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args
9705*da0073e9SAndroid Build Coastguard Worker
9706*da0073e9SAndroid Build Coastguard Worker        with open(doc_file, encoding="utf-8") as f:
9707*da0073e9SAndroid Build Coastguard Worker            doc_strs = f.read()
9708*da0073e9SAndroid Build Coastguard Worker
9709*da0073e9SAndroid Build Coastguard Worker        matches = re.findall(
9710*da0073e9SAndroid Build Coastguard Worker            r'add_docstr\(([^,]+?),[^"\']*?(?:"""|\'\'\')(.*?)(?:"""|\'\'\')(?:\.|,?[^,\)]*?\))',
9711*da0073e9SAndroid Build Coastguard Worker            doc_strs,
9712*da0073e9SAndroid Build Coastguard Worker            re.MULTILINE | re.DOTALL,
9713*da0073e9SAndroid Build Coastguard Worker        )
9714*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(matches)
9715*da0073e9SAndroid Build Coastguard Worker
9716*da0073e9SAndroid Build Coastguard Worker        for m in matches:
9717*da0073e9SAndroid Build Coastguard Worker            func = m[0].strip()
9718*da0073e9SAndroid Build Coastguard Worker            desc = m[1].strip()
9719*da0073e9SAndroid Build Coastguard Worker
9720*da0073e9SAndroid Build Coastguard Worker            for common_args in [multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args]:
9721*da0073e9SAndroid Build Coastguard Worker                for k, v in common_args.items():
9722*da0073e9SAndroid Build Coastguard Worker                    self.assertNotIn(v, desc, f'The argument description "{v}" in {func} can be '
9723*da0073e9SAndroid Build Coastguard Worker                                              f'replaced by {{{k}}}')
9724*da0073e9SAndroid Build Coastguard Worker
9725*da0073e9SAndroid Build Coastguard Worker    def test_doc(self):
9726*da0073e9SAndroid Build Coastguard Worker        checked_types = (types.MethodType, types.FunctionType,
9727*da0073e9SAndroid Build Coastguard Worker                         types.BuiltinFunctionType, types.BuiltinMethodType)
9728*da0073e9SAndroid Build Coastguard Worker
9729*da0073e9SAndroid Build Coastguard Worker        def _test_namespace(ns, *skips):
9730*da0073e9SAndroid Build Coastguard Worker            if isinstance(ns, types.ModuleType):
9731*da0073e9SAndroid Build Coastguard Worker                ns_name = ns.__name__
9732*da0073e9SAndroid Build Coastguard Worker            else:
9733*da0073e9SAndroid Build Coastguard Worker                ns_name = ns.__class__.__name__
9734*da0073e9SAndroid Build Coastguard Worker            skip_regexes = []
9735*da0073e9SAndroid Build Coastguard Worker            for r in skips:
9736*da0073e9SAndroid Build Coastguard Worker                if isinstance(r, str):
9737*da0073e9SAndroid Build Coastguard Worker                    skip_regexes.append(re.compile(f'^{re.escape(r)}$'))
9738*da0073e9SAndroid Build Coastguard Worker                else:
9739*da0073e9SAndroid Build Coastguard Worker                    skip_regexes.append(r)
9740*da0073e9SAndroid Build Coastguard Worker
9741*da0073e9SAndroid Build Coastguard Worker            for name in dir(ns):
9742*da0073e9SAndroid Build Coastguard Worker                if name.startswith('_'):
9743*da0073e9SAndroid Build Coastguard Worker                    continue
9744*da0073e9SAndroid Build Coastguard Worker                if name in ['real', 'imag']:
9745*da0073e9SAndroid Build Coastguard Worker                    y = torch.randn(1, dtype=torch.cfloat)
9746*da0073e9SAndroid Build Coastguard Worker                    var = getattr(y, name)
9747*da0073e9SAndroid Build Coastguard Worker                elif name in ["H", "mT", "mH"]:
9748*da0073e9SAndroid Build Coastguard Worker                    y = torch.randn(1, 1)
9749*da0073e9SAndroid Build Coastguard Worker                    var = getattr(y, name)
9750*da0073e9SAndroid Build Coastguard Worker                else:
9751*da0073e9SAndroid Build Coastguard Worker                    var = getattr(ns, name)
9752*da0073e9SAndroid Build Coastguard Worker                if not isinstance(var, checked_types):
9753*da0073e9SAndroid Build Coastguard Worker                    continue
9754*da0073e9SAndroid Build Coastguard Worker                doc = var.__doc__
9755*da0073e9SAndroid Build Coastguard Worker                has_doc = doc is not None and len(doc.strip()) > 0
9756*da0073e9SAndroid Build Coastguard Worker                full_name = ns_name + '.' + name
9757*da0073e9SAndroid Build Coastguard Worker                if any(r.match(name) for r in skip_regexes):
9758*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(has_doc,
9759*da0073e9SAndroid Build Coastguard Worker                                     f'New docs have been added for {full_name}, please remove '
9760*da0073e9SAndroid Build Coastguard Worker                                     'it from the skipped list in TestTorch.test_doc')
9761*da0073e9SAndroid Build Coastguard Worker                else:
9762*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(has_doc, f'{full_name} is missing documentation')
9763*da0073e9SAndroid Build Coastguard Worker
9764*da0073e9SAndroid Build Coastguard Worker        # FIXME: All of the following should be marked as expected failures
9765*da0073e9SAndroid Build Coastguard Worker        # so that it is easier to tell when missing documentation has been added.
9766*da0073e9SAndroid Build Coastguard Worker        # FIXME: fix all the skipped ones below!
9767*da0073e9SAndroid Build Coastguard Worker        _test_namespace(torch.randn(1),
9768*da0073e9SAndroid Build Coastguard Worker                        'as_strided_',
9769*da0073e9SAndroid Build Coastguard Worker                        re.compile('^clamp_(min|max)_?$'),
9770*da0073e9SAndroid Build Coastguard Worker                        'is_distributed',
9771*da0073e9SAndroid Build Coastguard Worker                        'is_nonzero',
9772*da0073e9SAndroid Build Coastguard Worker                        'is_same_size',
9773*da0073e9SAndroid Build Coastguard Worker                        'log_softmax',
9774*da0073e9SAndroid Build Coastguard Worker                        'map2_',
9775*da0073e9SAndroid Build Coastguard Worker                        'new',
9776*da0073e9SAndroid Build Coastguard Worker                        'reinforce',
9777*da0073e9SAndroid Build Coastguard Worker                        'relu',
9778*da0073e9SAndroid Build Coastguard Worker                        'relu_',
9779*da0073e9SAndroid Build Coastguard Worker                        'prelu',
9780*da0073e9SAndroid Build Coastguard Worker                        'resize',
9781*da0073e9SAndroid Build Coastguard Worker                        'resize_as',
9782*da0073e9SAndroid Build Coastguard Worker                        'softmax',
9783*da0073e9SAndroid Build Coastguard Worker                        'split_with_sizes',
9784*da0073e9SAndroid Build Coastguard Worker                        'unsafe_split_with_sizes',
9785*da0073e9SAndroid Build Coastguard Worker                        '_autocast_to_fp16',
9786*da0073e9SAndroid Build Coastguard Worker                        '_autocast_to_fp32',
9787*da0073e9SAndroid Build Coastguard Worker                        )
9788*da0073e9SAndroid Build Coastguard Worker
9789*da0073e9SAndroid Build Coastguard Worker        _test_namespace(torch.nn)
9790*da0073e9SAndroid Build Coastguard Worker        _test_namespace(torch.nn.functional, 'assert_int_or_pair')
9791*da0073e9SAndroid Build Coastguard Worker        # TODO: add torch.* tests when we have proper namespacing on ATen functions
9792*da0073e9SAndroid Build Coastguard Worker        # _test_namespace(torch)
9793*da0073e9SAndroid Build Coastguard Worker
9794*da0073e9SAndroid Build Coastguard Worker    # FIXME: deprecate torch.Tensor constructor
9795*da0073e9SAndroid Build Coastguard Worker    def test_tensor_ctor_scalar(self):
9796*da0073e9SAndroid Build Coastguard Worker        x = torch.Tensor(torch.tensor(1.0))
9797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.tensor(1.0))
9798*da0073e9SAndroid Build Coastguard Worker
9799*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_gradient(self):
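        # deepcopy of a tensor should also deepcopy its .grad, and deepcopy's memo
        # should preserve aliasing: copying [a, a.grad] together must keep the copied
        # grad identical to the copied list element.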
9800*da0073e9SAndroid Build Coastguard Worker        from copy import deepcopy
9801*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(10)
9802*da0073e9SAndroid Build Coastguard Worker        a.grad = torch.ones(10)
9803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, deepcopy(a).grad)
9804*da0073e9SAndroid Build Coastguard Worker        s = torch.zeros(10).to_sparse()
9805*da0073e9SAndroid Build Coastguard Worker        s.grad = torch.ones(10).to_sparse()
9806*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s.grad, deepcopy(s).grad)
9807*da0073e9SAndroid Build Coastguard Worker
9808*da0073e9SAndroid Build Coastguard Worker        # ensure sharing is not broken
9809*da0073e9SAndroid Build Coastguard Worker        c = deepcopy([a, a.grad])
9810*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(c[0].grad is c[1])
9811*da0073e9SAndroid Build Coastguard Worker
9812*da0073e9SAndroid Build Coastguard Worker    def test_tensor_base_init(self):
9813*da0073e9SAndroid Build Coastguard Worker        # Direct construction not OK
9814*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch._C.TensorBase())
9815*da0073e9SAndroid Build Coastguard Worker
9816*da0073e9SAndroid Build Coastguard Worker        # Subclassing it directly is not OK
9817*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot subclass"):
9818*da0073e9SAndroid Build Coastguard Worker            class Tfail(torch._C.TensorBase):
9819*da0073e9SAndroid Build Coastguard Worker                pass
9820*da0073e9SAndroid Build Coastguard Worker
9821*da0073e9SAndroid Build Coastguard Worker        # Doing so with Tensor is ok though
9822*da0073e9SAndroid Build Coastguard Worker        class T(torch.Tensor):
9823*da0073e9SAndroid Build Coastguard Worker            pass
9824*da0073e9SAndroid Build Coastguard Worker
9825*da0073e9SAndroid Build Coastguard Worker        T()
9826*da0073e9SAndroid Build Coastguard Worker
9827*da0073e9SAndroid Build Coastguard Worker    def test_storage_base_init(self):
9828*da0073e9SAndroid Build Coastguard Worker        # Direct construction not OK
9829*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch._C.StorageBase())
9830*da0073e9SAndroid Build Coastguard Worker
9831*da0073e9SAndroid Build Coastguard Worker        # But construction of subclass is OK
9832*da0073e9SAndroid Build Coastguard Worker        class T(torch._C.StorageBase):
9833*da0073e9SAndroid Build Coastguard Worker            pass
9834*da0073e9SAndroid Build Coastguard Worker
9835*da0073e9SAndroid Build Coastguard Worker        T()
9836*da0073e9SAndroid Build Coastguard Worker
9837*da0073e9SAndroid Build Coastguard Worker    def test_tensor_base_new(self):
9838*da0073e9SAndroid Build Coastguard Worker
9839*da0073e9SAndroid Build Coastguard Worker        # OK to call super().__new__, see
9840*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/57421
9841*da0073e9SAndroid Build Coastguard Worker        class TestTensor(torch.Tensor):
9842*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9843*da0073e9SAndroid Build Coastguard Worker            def __new__(cls, x, *args, **kwargs):
9844*da0073e9SAndroid Build Coastguard Worker                return super().__new__(cls, x, *args, **kwargs)
9845*da0073e9SAndroid Build Coastguard Worker
9846*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5)
9847*da0073e9SAndroid Build Coastguard Worker        test_tensor = TestTensor(x)
9848*da0073e9SAndroid Build Coastguard Worker
9849*da0073e9SAndroid Build Coastguard Worker    def test_storage_base_new(self):
9850*da0073e9SAndroid Build Coastguard Worker
9851*da0073e9SAndroid Build Coastguard Worker        # OK to call super().__new__, see
9852*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/57421
9853*da0073e9SAndroid Build Coastguard Worker        class TestStorage(torch._C.StorageBase):
9854*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9855*da0073e9SAndroid Build Coastguard Worker            def __new__(cls, x, *args, **kwargs):
9856*da0073e9SAndroid Build Coastguard Worker                return super().__new__(cls, x, *args, **kwargs)
9857*da0073e9SAndroid Build Coastguard Worker
9858*da0073e9SAndroid Build Coastguard Worker        x = torch.UntypedStorage(5)
9859*da0073e9SAndroid Build Coastguard Worker        test_storage = TestStorage(x)
9860*da0073e9SAndroid Build Coastguard Worker
9861*da0073e9SAndroid Build Coastguard Worker    def test_pyobj_preserved(self):
9862*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2)
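        # Attributes stashed on a tensor's __dict__ should survive even when the only
        # remaining reference is the C++ one held through y.grad; the Python object is
        # preserved rather than recreated when it is accessed again.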
9863*da0073e9SAndroid Build Coastguard Worker        x.foo = 2  # put something on __dict__
9864*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(2)
9865*da0073e9SAndroid Build Coastguard Worker        y.grad = x
9866*da0073e9SAndroid Build Coastguard Worker        del x  # x is dead in Python
9867*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad.foo, 2)
9868*da0073e9SAndroid Build Coastguard Worker        z = y.grad  # it's live
9869*da0073e9SAndroid Build Coastguard Worker        del z  # it's dead again
9870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad.foo, 2)
9871*da0073e9SAndroid Build Coastguard Worker
9872*da0073e9SAndroid Build Coastguard Worker    def test_subclass_preserved(self):
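        # As in test_pyobj_preserved, but additionally checks that the Python subclass
        # (MyTensor) survives the round trip through the C++ reference, not just the
        # attribute contents.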
9873*da0073e9SAndroid Build Coastguard Worker        class MyTensor(torch.Tensor):
9874*da0073e9SAndroid Build Coastguard Worker            pass
9875*da0073e9SAndroid Build Coastguard Worker
9876*da0073e9SAndroid Build Coastguard Worker        x = MyTensor(torch.empty(2))
9877*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(2)
9878*da0073e9SAndroid Build Coastguard Worker        y.grad = x
9879*da0073e9SAndroid Build Coastguard Worker        del x  # x is dead in Python
9880*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(y.grad), MyTensor)
9881*da0073e9SAndroid Build Coastguard Worker        z = y.grad  # it's live
9882*da0073e9SAndroid Build Coastguard Worker        del z  # it's dead again
9883*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(y.grad), MyTensor)
9884*da0073e9SAndroid Build Coastguard Worker
9885*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9886*da0073e9SAndroid Build Coastguard Worker    def test_storage_dealloc(self):
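        # The Tracker finalizer should fire only after the last Python reference to the
        # storage (s0 and its alias s1) has been dropped.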
9887*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
9888*da0073e9SAndroid Build Coastguard Worker        s0 = torch.UntypedStorage(10)
9889*da0073e9SAndroid Build Coastguard Worker        s1 = s0
9890*da0073e9SAndroid Build Coastguard Worker        s0._tracker = t
9891*da0073e9SAndroid Build Coastguard Worker        del t
9892*da0073e9SAndroid Build Coastguard Worker
9893*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9894*da0073e9SAndroid Build Coastguard Worker        del s0
9895*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9896*da0073e9SAndroid Build Coastguard Worker        del s1
9897*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
9898*da0073e9SAndroid Build Coastguard Worker
9899*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9900*da0073e9SAndroid Build Coastguard Worker    def test_storage_from_tensor_dealloc(self):
9901*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
9902*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10)
9903*da0073e9SAndroid Build Coastguard Worker        s0 = a.untyped_storage()
9904*da0073e9SAndroid Build Coastguard Worker        s0._tracker = t
9905*da0073e9SAndroid Build Coastguard Worker        del t
9906*da0073e9SAndroid Build Coastguard Worker
9907*da0073e9SAndroid Build Coastguard Worker        s1 = a.untyped_storage()
9908*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0 is s1)
9909*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(s1, '_tracker'))
9910*da0073e9SAndroid Build Coastguard Worker
9911*da0073e9SAndroid Build Coastguard Worker        del a
9912*da0073e9SAndroid Build Coastguard Worker
9913*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9914*da0073e9SAndroid Build Coastguard Worker        del s0
9915*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9916*da0073e9SAndroid Build Coastguard Worker        del s1
9917*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
9918*da0073e9SAndroid Build Coastguard Worker
9919*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9920*da0073e9SAndroid Build Coastguard Worker    def test_storage_from_tensor_dealloc_zombie(self):
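        # "Zombie" case: even after every Python reference to the storage is dropped,
        # the tensor still owns the underlying C++ storage, so the storage's Python
        # object (and its _tracker) is expected to stay alive until the tensor itself
        # is deleted.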
9921*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
9922*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10)
9923*da0073e9SAndroid Build Coastguard Worker        s0 = a.untyped_storage()
9924*da0073e9SAndroid Build Coastguard Worker        s0._tracker = t
9925*da0073e9SAndroid Build Coastguard Worker        del t
9926*da0073e9SAndroid Build Coastguard Worker
9927*da0073e9SAndroid Build Coastguard Worker        s1 = a.untyped_storage()
9928*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0 is s1)
9929*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(s1, '_tracker'))
9930*da0073e9SAndroid Build Coastguard Worker
9931*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9932*da0073e9SAndroid Build Coastguard Worker        del s0
9933*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9934*da0073e9SAndroid Build Coastguard Worker        del s1
9935*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9936*da0073e9SAndroid Build Coastguard Worker        del a
9937*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
9938*da0073e9SAndroid Build Coastguard Worker
9939*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9940*da0073e9SAndroid Build Coastguard Worker    def test_storage_from_tensor_dealloc_resurrected(self):
9941*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
9942*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10)
9943*da0073e9SAndroid Build Coastguard Worker        s0 = a.untyped_storage()
9944*da0073e9SAndroid Build Coastguard Worker        s0._tracker = t
9945*da0073e9SAndroid Build Coastguard Worker        del t
9946*da0073e9SAndroid Build Coastguard Worker
9947*da0073e9SAndroid Build Coastguard Worker        s1 = a.untyped_storage()
9948*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0 is s1)
9949*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(s1, '_tracker'))
9950*da0073e9SAndroid Build Coastguard Worker
9951*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9952*da0073e9SAndroid Build Coastguard Worker        del s0
9953*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9954*da0073e9SAndroid Build Coastguard Worker        del s1
9955*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9956*da0073e9SAndroid Build Coastguard Worker
9957*da0073e9SAndroid Build Coastguard Worker        s0 = a.untyped_storage()
9958*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s0, torch.UntypedStorage))
9959*da0073e9SAndroid Build Coastguard Worker
9960*da0073e9SAndroid Build Coastguard Worker        del a
9961*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9962*da0073e9SAndroid Build Coastguard Worker        del s0
9963*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
9964*da0073e9SAndroid Build Coastguard Worker
9965*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9966*da0073e9SAndroid Build Coastguard Worker    def test_storage_dealloc_resurrected(self):
9967*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
9968*da0073e9SAndroid Build Coastguard Worker        s = torch.UntypedStorage(10)
9969*da0073e9SAndroid Build Coastguard Worker        s._tracker = t
9970*da0073e9SAndroid Build Coastguard Worker        del t
9971*da0073e9SAndroid Build Coastguard Worker
9972*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(s)
9973*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9974*da0073e9SAndroid Build Coastguard Worker        del s
9975*da0073e9SAndroid Build Coastguard Worker
9976*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9977*da0073e9SAndroid Build Coastguard Worker
9978*da0073e9SAndroid Build Coastguard Worker        s = a.untyped_storage()
9979*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s, torch.UntypedStorage))
9980*da0073e9SAndroid Build Coastguard Worker
9981*da0073e9SAndroid Build Coastguard Worker        del a
9982*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
9983*da0073e9SAndroid Build Coastguard Worker        del s
9984*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
9985*da0073e9SAndroid Build Coastguard Worker
9986*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
9987*da0073e9SAndroid Build Coastguard Worker    def test_storage_dealloc_subclass_zombie(self):
9988*da0073e9SAndroid Build Coastguard Worker        class MyStorage(torch.UntypedStorage):
9989*da0073e9SAndroid Build Coastguard Worker            finalized_count = 0
9990*da0073e9SAndroid Build Coastguard Worker
9991*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
9992*da0073e9SAndroid Build Coastguard Worker                MyStorage.finalized_count += 1
9993*da0073e9SAndroid Build Coastguard Worker
9994*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
9995*da0073e9SAndroid Build Coastguard Worker        s = MyStorage(10)
9996*da0073e9SAndroid Build Coastguard Worker        s._tracker = t
9997*da0073e9SAndroid Build Coastguard Worker        del t
9998*da0073e9SAndroid Build Coastguard Worker
9999*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(s)
10000*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10001*da0073e9SAndroid Build Coastguard Worker        del s
10002*da0073e9SAndroid Build Coastguard Worker
10003*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(MyStorage.finalized_count, 0)
10004*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10005*da0073e9SAndroid Build Coastguard Worker
10006*da0073e9SAndroid Build Coastguard Worker        del a
10007*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(MyStorage.finalized_count, 1)
10008*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10009*da0073e9SAndroid Build Coastguard Worker
10010*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
10011*da0073e9SAndroid Build Coastguard Worker    def test_storage_dealloc_subclass_resurrected(self):
10012*da0073e9SAndroid Build Coastguard Worker        class MyStorage(torch.UntypedStorage):
10013*da0073e9SAndroid Build Coastguard Worker            finalized_count = 0
10014*da0073e9SAndroid Build Coastguard Worker
10015*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
10016*da0073e9SAndroid Build Coastguard Worker                MyStorage.finalized_count += 1
10017*da0073e9SAndroid Build Coastguard Worker
10018*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
10019*da0073e9SAndroid Build Coastguard Worker        s = MyStorage(10)
10020*da0073e9SAndroid Build Coastguard Worker        s._tracker = t
10021*da0073e9SAndroid Build Coastguard Worker        del t
10022*da0073e9SAndroid Build Coastguard Worker
10023*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(s)
10024*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10025*da0073e9SAndroid Build Coastguard Worker        del s
10026*da0073e9SAndroid Build Coastguard Worker
10027*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(MyStorage.finalized_count, 0)
10028*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10029*da0073e9SAndroid Build Coastguard Worker
10030*da0073e9SAndroid Build Coastguard Worker        s = a.untyped_storage()
10031*da0073e9SAndroid Build Coastguard Worker        del a
10032*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10033*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(MyStorage.finalized_count, 0)
10034*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s, MyStorage))
10035*da0073e9SAndroid Build Coastguard Worker        del s
10036*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(MyStorage.finalized_count, 1)
10037*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10038*da0073e9SAndroid Build Coastguard Worker
10039*da0073e9SAndroid Build Coastguard Worker    def test_tensor_slot_dealloc(self):
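        # Values stored in __slots__ of Tensor subclasses (across an inheritance chain)
        # should be released when the tensor is deallocated, which the two trackers
        # below verify.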
10040*da0073e9SAndroid Build Coastguard Worker
10041*da0073e9SAndroid Build Coastguard Worker        class SlotTensor1(torch.Tensor):
10042*da0073e9SAndroid Build Coastguard Worker            __slots__ = ['slot1']
10043*da0073e9SAndroid Build Coastguard Worker
10044*da0073e9SAndroid Build Coastguard Worker        class SlotTensor2(SlotTensor1):
10045*da0073e9SAndroid Build Coastguard Worker            __slots__ = ['slot2']
10046*da0073e9SAndroid Build Coastguard Worker
10047*da0073e9SAndroid Build Coastguard Worker        m1, t1 = Tracker.make()
10048*da0073e9SAndroid Build Coastguard Worker        m2, t2 = Tracker.make()
10049*da0073e9SAndroid Build Coastguard Worker        slot_tensor = SlotTensor2(torch.empty(2))
10050*da0073e9SAndroid Build Coastguard Worker        slot_tensor.slot1 = t1
10051*da0073e9SAndroid Build Coastguard Worker        slot_tensor.slot2 = t2
10052*da0073e9SAndroid Build Coastguard Worker        del t1
10053*da0073e9SAndroid Build Coastguard Worker        del t2
10054*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m1[0])
10055*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m2[0])
10056*da0073e9SAndroid Build Coastguard Worker        del slot_tensor
10057*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m1[0])
10058*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2[0])
10059*da0073e9SAndroid Build Coastguard Worker
10060*da0073e9SAndroid Build Coastguard Worker    def test_storage_slot_dealloc(self):
10061*da0073e9SAndroid Build Coastguard Worker
10062*da0073e9SAndroid Build Coastguard Worker        class SlotStorage1(torch._C.StorageBase):
10063*da0073e9SAndroid Build Coastguard Worker            __slots__ = ['slot1']
10064*da0073e9SAndroid Build Coastguard Worker
10065*da0073e9SAndroid Build Coastguard Worker        class SlotStorage2(SlotStorage1):
10066*da0073e9SAndroid Build Coastguard Worker            __slots__ = ['slot2']
10067*da0073e9SAndroid Build Coastguard Worker
10068*da0073e9SAndroid Build Coastguard Worker        m1, t1 = Tracker.make()
10069*da0073e9SAndroid Build Coastguard Worker        m2, t2 = Tracker.make()
10070*da0073e9SAndroid Build Coastguard Worker        slot_storage = SlotStorage2(torch.UntypedStorage(2))
10071*da0073e9SAndroid Build Coastguard Worker        slot_storage.slot1 = t1
10072*da0073e9SAndroid Build Coastguard Worker        slot_storage.slot2 = t2
10073*da0073e9SAndroid Build Coastguard Worker        del t1
10074*da0073e9SAndroid Build Coastguard Worker        del t2
10075*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m1[0])
10076*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m2[0])
10077*da0073e9SAndroid Build Coastguard Worker        del slot_storage
10078*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m1[0])
10079*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2[0])
10080*da0073e9SAndroid Build Coastguard Worker
10081*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10082*da0073e9SAndroid Build Coastguard Worker    def test_tensor_dict_dealloc(self):
10083*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
10084*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2)
10085*da0073e9SAndroid Build Coastguard Worker        x.arf = t
10086*da0073e9SAndroid Build Coastguard Worker        del t
10087*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10088*da0073e9SAndroid Build Coastguard Worker        del x
10089*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10090*da0073e9SAndroid Build Coastguard Worker
10091*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10092*da0073e9SAndroid Build Coastguard Worker    def test_storage_dict_dealloc(self):
10093*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
10094*da0073e9SAndroid Build Coastguard Worker        x = torch.UntypedStorage(2)
10095*da0073e9SAndroid Build Coastguard Worker        x.arf = t
10096*da0073e9SAndroid Build Coastguard Worker        del t
10097*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10098*da0073e9SAndroid Build Coastguard Worker        del x
10099*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10100*da0073e9SAndroid Build Coastguard Worker
10101*da0073e9SAndroid Build Coastguard Worker    def test_tensor_finalizer_dealloc(self):
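        # A user-defined __del__ on a Tensor subclass should run when the instance is
        # deallocated.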
10102*da0073e9SAndroid Build Coastguard Worker        m = [False]
10103*da0073e9SAndroid Build Coastguard Worker
10104*da0073e9SAndroid Build Coastguard Worker        class FinalizerTensor(torch.Tensor):
10105*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
10106*da0073e9SAndroid Build Coastguard Worker                m[0] = True
10107*da0073e9SAndroid Build Coastguard Worker
10108*da0073e9SAndroid Build Coastguard Worker        fin_tensor = FinalizerTensor(torch.empty(2))
10109*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10110*da0073e9SAndroid Build Coastguard Worker        del fin_tensor
10111*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10112*da0073e9SAndroid Build Coastguard Worker
10113*da0073e9SAndroid Build Coastguard Worker    def test_storage_finalizer_dealloc(self):
10114*da0073e9SAndroid Build Coastguard Worker        m = [False]
10115*da0073e9SAndroid Build Coastguard Worker
10116*da0073e9SAndroid Build Coastguard Worker        class FinalizerStorage(torch._C.StorageBase):
10117*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
10118*da0073e9SAndroid Build Coastguard Worker                m[0] = True
10119*da0073e9SAndroid Build Coastguard Worker
10120*da0073e9SAndroid Build Coastguard Worker        fin_storage = FinalizerStorage(torch.UntypedStorage(2))
10121*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10122*da0073e9SAndroid Build Coastguard Worker        del fin_storage
10123*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10124*da0073e9SAndroid Build Coastguard Worker
10125*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10126*da0073e9SAndroid Build Coastguard Worker    def test_tensor_weakref_dealloc(self):
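        # Deleting the last reference should invoke the weakref callback and leave the
        # weak reference cleared (wref() returns None).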
10127*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2)
10128*da0073e9SAndroid Build Coastguard Worker        m = [False]
10129*da0073e9SAndroid Build Coastguard Worker
10130*da0073e9SAndroid Build Coastguard Worker        def cb(r):
10131*da0073e9SAndroid Build Coastguard Worker            m[0] = True
10132*da0073e9SAndroid Build Coastguard Worker
10133*da0073e9SAndroid Build Coastguard Worker        wref = weakref.ref(x, cb)
10134*da0073e9SAndroid Build Coastguard Worker        del x
10135*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wref(), None)
10137*da0073e9SAndroid Build Coastguard Worker
10138*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10139*da0073e9SAndroid Build Coastguard Worker    def test_storage_weakref_dealloc(self):
10140*da0073e9SAndroid Build Coastguard Worker
10141*da0073e9SAndroid Build Coastguard Worker        x = torch.UntypedStorage(2)
10142*da0073e9SAndroid Build Coastguard Worker        m = [False]
10143*da0073e9SAndroid Build Coastguard Worker
10144*da0073e9SAndroid Build Coastguard Worker        def cb(r):
10145*da0073e9SAndroid Build Coastguard Worker            m[0] = True
10146*da0073e9SAndroid Build Coastguard Worker
10147*da0073e9SAndroid Build Coastguard Worker        wref = weakref.ref(x, cb)
10148*da0073e9SAndroid Build Coastguard Worker        del x
10149*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10150*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wref(), None)
10151*da0073e9SAndroid Build Coastguard Worker
10152*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10153*da0073e9SAndroid Build Coastguard Worker    def test_tensor_cycle_via_dict(self):
10154*da0073e9SAndroid Build Coastguard Worker        m1, t1 = Tracker.make()
10155*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2)
10156*da0073e9SAndroid Build Coastguard Worker        x._tracker = t1
10157*da0073e9SAndroid Build Coastguard Worker        del t1
10158*da0073e9SAndroid Build Coastguard Worker
10159*da0073e9SAndroid Build Coastguard Worker        m2, t2 = Tracker.make()
10160*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(2)
10161*da0073e9SAndroid Build Coastguard Worker        y._tracker = t2
10162*da0073e9SAndroid Build Coastguard Worker        del t2
10163*da0073e9SAndroid Build Coastguard Worker
10164*da0073e9SAndroid Build Coastguard Worker        x._loop = y
10165*da0073e9SAndroid Build Coastguard Worker        y._loop = x
10166*da0073e9SAndroid Build Coastguard Worker
10167*da0073e9SAndroid Build Coastguard Worker        # C++ reference should keep the cycle live!
10168*da0073e9SAndroid Build Coastguard Worker        # This exercises THPVariable_subtype_traverse
10169*da0073e9SAndroid Build Coastguard Worker        # NB: Because z.grad is a reference done entirely in C++, cycles
10170*da0073e9SAndroid Build Coastguard Worker        # involving it directly are NOT broken by Python GC; you've
10171*da0073e9SAndroid Build Coastguard Worker        # set up a good old C++ reference cycle which we cannot safely
10172*da0073e9SAndroid Build Coastguard Worker        # break (because C++ references are allowed to be accessed
10173*da0073e9SAndroid Build Coastguard Worker        # from multiple threads) (TODO: except maybe if you can prove that
10174*da0073e9SAndroid Build Coastguard Worker        # only Python has access to the C++ object, in which case you can
10175*da0073e9SAndroid Build Coastguard Worker        # also prove that no multithreaded access occurs)
10176*da0073e9SAndroid Build Coastguard Worker        z = torch.empty(2)
10177*da0073e9SAndroid Build Coastguard Worker        z.grad = x
10178*da0073e9SAndroid Build Coastguard Worker
10179*da0073e9SAndroid Build Coastguard Worker        del x
10180*da0073e9SAndroid Build Coastguard Worker        del y
10181*da0073e9SAndroid Build Coastguard Worker
10182*da0073e9SAndroid Build Coastguard Worker        gc.collect()
10183*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m1[0])
10184*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m2[0])
10185*da0073e9SAndroid Build Coastguard Worker
10186*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
10187*da0073e9SAndroid Build Coastguard Worker            del z
10188*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m1[0])
10189*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m2[0])
10190*da0073e9SAndroid Build Coastguard Worker
10191*da0073e9SAndroid Build Coastguard Worker        gc.collect()
10192*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m1[0])
10193*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2[0])
10194*da0073e9SAndroid Build Coastguard Worker
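    # Hedged sketch (added for illustration, not exercised by the suite): with
    # no C++-side reference like z.grad keeping the objects alive, a pure
    # Python cycle through tensor __dict__ attributes is reclaimed by a single
    # gc.collect(), which is what the traversal machinery above is meant to enable.
    def _dict_cycle_sketch(self):
        a, b = torch.empty(2), torch.empty(2)
        a._loop, b._loop = b, a          # Python-level reference cycle
        w = weakref.ref(a)
        del a, b
        gc.collect()                     # cycle collector breaks the loop
        self.assertIsNone(w())
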
10195*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10196*da0073e9SAndroid Build Coastguard Worker    def test_storage_cycle_via_dict(self):
10197*da0073e9SAndroid Build Coastguard Worker        m1, t1 = Tracker.make()
10198*da0073e9SAndroid Build Coastguard Worker        x = torch.UntypedStorage(2)
10199*da0073e9SAndroid Build Coastguard Worker        x._tracker = t1
10200*da0073e9SAndroid Build Coastguard Worker        del t1
10201*da0073e9SAndroid Build Coastguard Worker
10202*da0073e9SAndroid Build Coastguard Worker        m2, t2 = Tracker.make()
10203*da0073e9SAndroid Build Coastguard Worker        y = torch.UntypedStorage(2)
10204*da0073e9SAndroid Build Coastguard Worker        y._tracker = t2
10205*da0073e9SAndroid Build Coastguard Worker        del t2
10206*da0073e9SAndroid Build Coastguard Worker
10207*da0073e9SAndroid Build Coastguard Worker        x._loop = y
10208*da0073e9SAndroid Build Coastguard Worker        y._loop = x
10209*da0073e9SAndroid Build Coastguard Worker
10210*da0073e9SAndroid Build Coastguard Worker        # C++ reference should keep the cycle live!
10211*da0073e9SAndroid Build Coastguard Worker        # This exercises THPVariable_subtype_traverse
10212*da0073e9SAndroid Build Coastguard Worker        # NB: Because z.grad is a reference done entirely in C++, cycles
10213*da0073e9SAndroid Build Coastguard Worker        # involving it directly are NOT broken by Python GC; you've
10214*da0073e9SAndroid Build Coastguard Worker        # set up a good old C++ reference cycle which we cannot safely
10215*da0073e9SAndroid Build Coastguard Worker        # break (because C++ references are allowed to be accessed
10216*da0073e9SAndroid Build Coastguard Worker        # multithreaded-ly) (TODO: except maybe if you can prove that
10217*da0073e9SAndroid Build Coastguard Worker        # only Python has access to the C++ object, in which case you can
10218*da0073e9SAndroid Build Coastguard Worker        # also prove that no multithreaded access occurs)
10219*da0073e9SAndroid Build Coastguard Worker        z = torch.UntypedStorage(2)
10220*da0073e9SAndroid Build Coastguard Worker        z.grad = x
10221*da0073e9SAndroid Build Coastguard Worker
10222*da0073e9SAndroid Build Coastguard Worker        del x
10223*da0073e9SAndroid Build Coastguard Worker        del y
10224*da0073e9SAndroid Build Coastguard Worker
10225*da0073e9SAndroid Build Coastguard Worker        gc.collect()
10226*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m1[0])
10227*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m2[0])
10228*da0073e9SAndroid Build Coastguard Worker
10229*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
10230*da0073e9SAndroid Build Coastguard Worker            del z
10231*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m1[0])
10232*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m2[0])
10233*da0073e9SAndroid Build Coastguard Worker
10234*da0073e9SAndroid Build Coastguard Worker        gc.collect()
10235*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m1[0])
10236*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2[0])
10237*da0073e9SAndroid Build Coastguard Worker
10238*da0073e9SAndroid Build Coastguard Worker    def test_tensor_cycle_via_slots(self):
10239*da0073e9SAndroid Build Coastguard Worker        m1 = [False]
10240*da0073e9SAndroid Build Coastguard Worker        m2 = [False]
10241*da0073e9SAndroid Build Coastguard Worker
10242*da0073e9SAndroid Build Coastguard Worker        class SlotTensor1(torch.Tensor):
10243*da0073e9SAndroid Build Coastguard Worker            __slots__ = ['slot1']
10244*da0073e9SAndroid Build Coastguard Worker
10245*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
10246*da0073e9SAndroid Build Coastguard Worker                m1[0] = True
10247*da0073e9SAndroid Build Coastguard Worker
10248*da0073e9SAndroid Build Coastguard Worker        class SlotTensor2(SlotTensor1):
10249*da0073e9SAndroid Build Coastguard Worker            __slots__ = ['slot2']
10250*da0073e9SAndroid Build Coastguard Worker
10251*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
10252*da0073e9SAndroid Build Coastguard Worker                m2[0] = True
10253*da0073e9SAndroid Build Coastguard Worker
10254*da0073e9SAndroid Build Coastguard Worker        x = SlotTensor1(torch.empty(2))
10255*da0073e9SAndroid Build Coastguard Worker        y = SlotTensor2(torch.empty(2))
10256*da0073e9SAndroid Build Coastguard Worker
10257*da0073e9SAndroid Build Coastguard Worker        x.slot1 = y
10258*da0073e9SAndroid Build Coastguard Worker        y.slot2 = x
10259*da0073e9SAndroid Build Coastguard Worker
10260*da0073e9SAndroid Build Coastguard Worker        del x
10261*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
10262*da0073e9SAndroid Build Coastguard Worker            del y
10263*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m1[0])
10264*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m2[0])
10265*da0073e9SAndroid Build Coastguard Worker
10266*da0073e9SAndroid Build Coastguard Worker        gc.collect()
10267*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m1[0])
10268*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2[0])
10269*da0073e9SAndroid Build Coastguard Worker
10270*da0073e9SAndroid Build Coastguard Worker    def test_storage_cycle_via_slots(self):
10271*da0073e9SAndroid Build Coastguard Worker        m1 = [False]
10272*da0073e9SAndroid Build Coastguard Worker        m2 = [False]
10273*da0073e9SAndroid Build Coastguard Worker
10274*da0073e9SAndroid Build Coastguard Worker        class SlotStorage1(torch._C.StorageBase):
10275*da0073e9SAndroid Build Coastguard Worker            __slots__ = ['slot1']
10276*da0073e9SAndroid Build Coastguard Worker
10277*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
10278*da0073e9SAndroid Build Coastguard Worker                m1[0] = True
10279*da0073e9SAndroid Build Coastguard Worker
10280*da0073e9SAndroid Build Coastguard Worker        class SlotStorage2(SlotStorage1):
10281*da0073e9SAndroid Build Coastguard Worker            __slots__ = ['slot2']
10282*da0073e9SAndroid Build Coastguard Worker
10283*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
10284*da0073e9SAndroid Build Coastguard Worker                m2[0] = True
10285*da0073e9SAndroid Build Coastguard Worker
10286*da0073e9SAndroid Build Coastguard Worker        x = SlotStorage1(torch.UntypedStorage(2))
10287*da0073e9SAndroid Build Coastguard Worker        y = SlotStorage2(torch.UntypedStorage(2))
10288*da0073e9SAndroid Build Coastguard Worker
10289*da0073e9SAndroid Build Coastguard Worker        x.slot1 = y
10290*da0073e9SAndroid Build Coastguard Worker        y.slot2 = x
10291*da0073e9SAndroid Build Coastguard Worker
10292*da0073e9SAndroid Build Coastguard Worker        del x
10293*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
10294*da0073e9SAndroid Build Coastguard Worker            del y
10295*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m1[0])
10296*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m2[0])
10297*da0073e9SAndroid Build Coastguard Worker
10298*da0073e9SAndroid Build Coastguard Worker        gc.collect()
10299*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m1[0])
10300*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2[0])
10301*da0073e9SAndroid Build Coastguard Worker
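    # Hedged sketch (illustrative only, not part of the tests above): the same
    # collection behaviour is expected when the cycle lives in __slots__ rather
    # than __dict__, since slot references are also visited during traversal.
    def _slot_cycle_sketch(self):
        class CycleTensor(torch.Tensor):
            __slots__ = ['other']

        a = CycleTensor(torch.empty(2))
        b = CycleTensor(torch.empty(2))
        a.other, b.other = b, a          # cycle held entirely in slots
        w = weakref.ref(a)
        del a, b
        gc.collect()
        self.assertIsNone(w())
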
10302*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
10303*da0073e9SAndroid Build Coastguard Worker    def test_storage_preserve_nonhermetic_in_hermetic_context(self):
10304*da0073e9SAndroid Build Coastguard Worker        from torch.library import Library, impl
10305*da0073e9SAndroid Build Coastguard Worker        global _my_storage
10306*da0073e9SAndroid Build Coastguard Worker
10307*da0073e9SAndroid Build Coastguard Worker        my_lib = Library("my_lib", "DEF")  # noqa: TOR901
10308*da0073e9SAndroid Build Coastguard Worker        my_lib.define('my_func() -> None')
10309*da0073e9SAndroid Build Coastguard Worker
10310*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.])
10311*da0073e9SAndroid Build Coastguard Worker        _my_storage = a.untyped_storage()
10312*da0073e9SAndroid Build Coastguard Worker
10313*da0073e9SAndroid Build Coastguard Worker        m, t = Tracker.make()
10314*da0073e9SAndroid Build Coastguard Worker        _my_storage._tracker = t
10315*da0073e9SAndroid Build Coastguard Worker        del t
10316*da0073e9SAndroid Build Coastguard Worker
10317*da0073e9SAndroid Build Coastguard Worker        @impl(my_lib, 'my_func', '')
10318*da0073e9SAndroid Build Coastguard Worker        def my_func():
10319*da0073e9SAndroid Build Coastguard Worker            global _my_storage
10320*da0073e9SAndroid Build Coastguard Worker            del _my_storage
10321*da0073e9SAndroid Build Coastguard Worker
10322*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10323*da0073e9SAndroid Build Coastguard Worker        torch.ops.my_lib.my_func()
10324*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m[0])
10325*da0073e9SAndroid Build Coastguard Worker
10326*da0073e9SAndroid Build Coastguard Worker        s = a.untyped_storage()
10327*da0073e9SAndroid Build Coastguard Worker        del a
10328*da0073e9SAndroid Build Coastguard Worker        del s
10329*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m[0])
10330*da0073e9SAndroid Build Coastguard Worker
10331*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_autograd?
10332*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
10333*da0073e9SAndroid Build Coastguard Worker    def test_backward_hooks_traverse(self):
10334*da0073e9SAndroid Build Coastguard Worker        m1, t1 = Tracker.make()
10335*da0073e9SAndroid Build Coastguard Worker        m2, t2 = Tracker.make()
10336*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2, requires_grad=True)
10337*da0073e9SAndroid Build Coastguard Worker        x._tracker = t1
10338*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(2, requires_grad=True)
10339*da0073e9SAndroid Build Coastguard Worker        y._tracker = t2
10340*da0073e9SAndroid Build Coastguard Worker        del t1
10341*da0073e9SAndroid Build Coastguard Worker        del t2
10342*da0073e9SAndroid Build Coastguard Worker
10343*da0073e9SAndroid Build Coastguard Worker        # this hits a special setter; it's not just a __dict__ entry
10344*da0073e9SAndroid Build Coastguard Worker        x._backward_hooks = y
10345*da0073e9SAndroid Build Coastguard Worker        y._backward_hooks = x
10346*da0073e9SAndroid Build Coastguard Worker
10347*da0073e9SAndroid Build Coastguard Worker        del x
10348*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
10349*da0073e9SAndroid Build Coastguard Worker            del y
10350*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m1[0])
10351*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m2[0])
10352*da0073e9SAndroid Build Coastguard Worker
10353*da0073e9SAndroid Build Coastguard Worker        gc.collect()
10354*da0073e9SAndroid Build Coastguard Worker
10355*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m1[0])
10356*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2[0])
10357*da0073e9SAndroid Build Coastguard Worker
10358*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10359*da0073e9SAndroid Build Coastguard Worker    def test_tensor_dead_weak_ref(self):
10360*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2)
10361*da0073e9SAndroid Build Coastguard Worker        w_x = weakref.ref(x)
10362*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(2)
10363*da0073e9SAndroid Build Coastguard Worker        y.grad = x
10364*da0073e9SAndroid Build Coastguard Worker        del x
10365*da0073e9SAndroid Build Coastguard Worker
10366*da0073e9SAndroid Build Coastguard Worker        x = w_x()
10367*da0073e9SAndroid Build Coastguard Worker        # Ideally, x would keep the tensor live.  But CPython doesn't
10368*da0073e9SAndroid Build Coastguard Worker        # provide enough hooks to do this.  So it will go dead and x
10369*da0073e9SAndroid Build Coastguard Worker        # will transmute into an undefined tensor.  Not great, but the
10370*da0073e9SAndroid Build Coastguard Worker        # best we can do.
10371*da0073e9SAndroid Build Coastguard Worker        del y
10372*da0073e9SAndroid Build Coastguard Worker
10373*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: x.sigmoid())
10374*da0073e9SAndroid Build Coastguard Worker
10375*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10376*da0073e9SAndroid Build Coastguard Worker    def test_storage_dead_weak_ref(self):
10377*da0073e9SAndroid Build Coastguard Worker        x = torch.UntypedStorage(2)
10378*da0073e9SAndroid Build Coastguard Worker        w_x = weakref.ref(x)
10379*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(x)
10380*da0073e9SAndroid Build Coastguard Worker        del x
10381*da0073e9SAndroid Build Coastguard Worker
10382*da0073e9SAndroid Build Coastguard Worker        x = w_x()
10383*da0073e9SAndroid Build Coastguard Worker        # Ideally, x would keep the storage live.  But CPython doesn't
10384*da0073e9SAndroid Build Coastguard Worker        # provide enough hooks to do this.  So it will go dead and x
10385*da0073e9SAndroid Build Coastguard Worker        # will transmute into a storage with a null StorageImpl. Not great, but the
10386*da0073e9SAndroid Build Coastguard Worker        # best we can do.
10387*da0073e9SAndroid Build Coastguard Worker        del y
10388*da0073e9SAndroid Build Coastguard Worker
10389*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
10390*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
10391*da0073e9SAndroid Build Coastguard Worker
10392*da0073e9SAndroid Build Coastguard Worker    def test_tensor_resurrected_weak_ref(self):
10393*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2)
10394*da0073e9SAndroid Build Coastguard Worker        w_x = weakref.ref(x)
10395*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(2)
10396*da0073e9SAndroid Build Coastguard Worker        y.grad = x
10397*da0073e9SAndroid Build Coastguard Worker        del x
10398*da0073e9SAndroid Build Coastguard Worker
10399*da0073e9SAndroid Build Coastguard Worker        x = w_x()
10400*da0073e9SAndroid Build Coastguard Worker        # Use this to manually fix weak references after dereferencing them
10401*da0073e9SAndroid Build Coastguard Worker        x._fix_weakref()
10402*da0073e9SAndroid Build Coastguard Worker        del y
10403*da0073e9SAndroid Build Coastguard Worker        x.sigmoid()
10404*da0073e9SAndroid Build Coastguard Worker
10405*da0073e9SAndroid Build Coastguard Worker    def test_storage_resurrected_weak_ref(self):
10406*da0073e9SAndroid Build Coastguard Worker        x = torch.UntypedStorage(2)
10407*da0073e9SAndroid Build Coastguard Worker        w_x = weakref.ref(x)
10408*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(x)
10409*da0073e9SAndroid Build Coastguard Worker        del x
10410*da0073e9SAndroid Build Coastguard Worker
10411*da0073e9SAndroid Build Coastguard Worker        x = w_x()
10412*da0073e9SAndroid Build Coastguard Worker        # Use this to manually fix weak references after dereferencing them
10413*da0073e9SAndroid Build Coastguard Worker        x._fix_weakref()
10414*da0073e9SAndroid Build Coastguard Worker        del y
10415*da0073e9SAndroid Build Coastguard Worker        x.float()
10416*da0073e9SAndroid Build Coastguard Worker
10417*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10418*da0073e9SAndroid Build Coastguard Worker    def test_tensor_fix_weakref_no_leak(self):
10419*da0073e9SAndroid Build Coastguard Worker        import weakref
10420*da0073e9SAndroid Build Coastguard Worker
10421*da0073e9SAndroid Build Coastguard Worker        called = False
10422*da0073e9SAndroid Build Coastguard Worker
10423*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1)
10424*da0073e9SAndroid Build Coastguard Worker
10425*da0073e9SAndroid Build Coastguard Worker        def callback(w):
10426*da0073e9SAndroid Build Coastguard Worker            nonlocal called
10427*da0073e9SAndroid Build Coastguard Worker            called = True
10428*da0073e9SAndroid Build Coastguard Worker        wa = weakref.ref(a, callback)
10429*da0073e9SAndroid Build Coastguard Worker        a._fix_weakref()
10430*da0073e9SAndroid Build Coastguard Worker        del a
10431*da0073e9SAndroid Build Coastguard Worker
10432*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(called)
10433*da0073e9SAndroid Build Coastguard Worker
10434*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
10435*da0073e9SAndroid Build Coastguard Worker    def test_storage_fix_weakref_no_leak(self):
10436*da0073e9SAndroid Build Coastguard Worker        import weakref
10437*da0073e9SAndroid Build Coastguard Worker
10438*da0073e9SAndroid Build Coastguard Worker        called = False
10439*da0073e9SAndroid Build Coastguard Worker
10440*da0073e9SAndroid Build Coastguard Worker        a = torch.UntypedStorage(1)
10441*da0073e9SAndroid Build Coastguard Worker
10442*da0073e9SAndroid Build Coastguard Worker        def callback(w):
10443*da0073e9SAndroid Build Coastguard Worker            nonlocal called
10444*da0073e9SAndroid Build Coastguard Worker            called = True
10445*da0073e9SAndroid Build Coastguard Worker        wa = weakref.ref(a, callback)
10446*da0073e9SAndroid Build Coastguard Worker        a._fix_weakref()
10447*da0073e9SAndroid Build Coastguard Worker        del a
10448*da0073e9SAndroid Build Coastguard Worker
10449*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(called)
10450*da0073e9SAndroid Build Coastguard Worker
10451*da0073e9SAndroid Build Coastguard Worker    # FIXME: move to test_linalg
10452*da0073e9SAndroid Build Coastguard Worker    @torch.inference_mode()
10453*da0073e9SAndroid Build Coastguard Worker    def test_bmm_multithreaded(self):
10454*da0073e9SAndroid Build Coastguard Worker        device = 'cpu'
10455*da0073e9SAndroid Build Coastguard Worker        num_threads = torch.get_num_threads()
10456*da0073e9SAndroid Build Coastguard Worker
10457*da0073e9SAndroid Build Coastguard Worker        torch.set_num_threads(4)
10458*da0073e9SAndroid Build Coastguard Worker        batch_sizes = [1, 10]
10459*da0073e9SAndroid Build Coastguard Worker        M, N, O = 23, 8, 12
10460*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
10461*da0073e9SAndroid Build Coastguard Worker        numpy_dtype = dtype
10462*da0073e9SAndroid Build Coastguard Worker
10463*da0073e9SAndroid Build Coastguard Worker        def invert_perm(p):
10464*da0073e9SAndroid Build Coastguard Worker            d = {x: i for i, x in enumerate(p)}
10465*da0073e9SAndroid Build Coastguard Worker            return (d[0], d[1], d[2])
10466*da0073e9SAndroid Build Coastguard Worker
10467*da0073e9SAndroid Build Coastguard Worker        def generate_inputs(num_batches):
10468*da0073e9SAndroid Build Coastguard Worker            # transposed tensors
10469*da0073e9SAndroid Build Coastguard Worker            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
10470*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
10471*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
10472*da0073e9SAndroid Build Coastguard Worker                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
10473*da0073e9SAndroid Build Coastguard Worker                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
10474*da0073e9SAndroid Build Coastguard Worker                yield b1, b2
10475*da0073e9SAndroid Build Coastguard Worker            # broadcasting tensors
10476*da0073e9SAndroid Build Coastguard Worker            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
10477*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
10478*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
10479*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N)
10480*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O)
10481*da0073e9SAndroid Build Coastguard Worker                yield b1, b2
10482*da0073e9SAndroid Build Coastguard Worker            # zero-sized tensors
10483*da0073e9SAndroid Build Coastguard Worker            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
10484*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
10485*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
10486*da0073e9SAndroid Build Coastguard Worker                b1 = torch.randn(shape1, dtype=dtype, device=device)
10487*da0073e9SAndroid Build Coastguard Worker                b2 = torch.randn(shape2, dtype=dtype, device=device)
10488*da0073e9SAndroid Build Coastguard Worker                yield b1, b2
10489*da0073e9SAndroid Build Coastguard Worker
10490*da0073e9SAndroid Build Coastguard Worker        try:
10491*da0073e9SAndroid Build Coastguard Worker            for num_batches in batch_sizes:
10492*da0073e9SAndroid Build Coastguard Worker                for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))):
10493*da0073e9SAndroid Build Coastguard Worker                    res1 = torch.bmm(b1, b2)
10494*da0073e9SAndroid Build Coastguard Worker                    res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
10495*da0073e9SAndroid Build Coastguard Worker                        .permute(perm3).contiguous().permute(invert_perm(perm3))
10496*da0073e9SAndroid Build Coastguard Worker                    torch.bmm(b1, b2, out=res2)
10497*da0073e9SAndroid Build Coastguard Worker                    expect = torch.from_numpy(
10498*da0073e9SAndroid Build Coastguard Worker                        b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
10499*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expect, res1)
10500*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expect, res2)
10501*da0073e9SAndroid Build Coastguard Worker        finally:
10502*da0073e9SAndroid Build Coastguard Worker            torch.set_num_threads(num_threads)
10503*da0073e9SAndroid Build Coastguard Worker
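    # Hedged sketch (illustrative only) of the permute/invert_perm trick used in
    # generate_inputs above: permuting, forcing a copy with .contiguous(), then
    # applying the inverse permutation keeps the logical values of the tensor
    # while changing its memory layout (strides).
    def _invert_perm_sketch(self):
        t = torch.arange(6.).reshape(1, 2, 3)
        perm = (2, 0, 1)
        inv = tuple({p: i for i, p in enumerate(perm)}[d] for d in range(3))
        u = t.permute(perm).contiguous().permute(inv)
        self.assertEqual(t, u)                       # same logical values
        self.assertNotEqual(t.stride(), u.stride())  # different layout
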
10504*da0073e9SAndroid Build Coastguard Worker    def test_conj_neg_tolist(self):
10505*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, dtype=torch.cfloat)
10506*da0073e9SAndroid Build Coastguard Worker        y1 = x.conj()
10507*da0073e9SAndroid Build Coastguard Worker        y1_expect = x.conj_physical()
10508*da0073e9SAndroid Build Coastguard Worker        y2 = y1.imag
10509*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y1, y1_expect.tolist())
10510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y2, y1_expect.imag.tolist())
10511*da0073e9SAndroid Build Coastguard Worker
10512*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(torch.backends.cuda.is_built(), "Skipped for cuda-enabled build")
10513*da0073e9SAndroid Build Coastguard Worker    def test_no_cuda_monkeypatch(self):
10514*da0073e9SAndroid Build Coastguard Worker        # Note that this is not in test_cuda.py as this whole file is skipped when cuda
10515*da0073e9SAndroid Build Coastguard Worker        # is not available.
10516*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Stream"):
10517*da0073e9SAndroid Build Coastguard Worker            torch.cuda.Stream()
10518*da0073e9SAndroid Build Coastguard Worker
10519*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Event"):
10520*da0073e9SAndroid Build Coastguard Worker            torch.cuda.Event()
10521*da0073e9SAndroid Build Coastguard Worker
10522*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class CUDAGraph"):
10523*da0073e9SAndroid Build Coastguard Worker            torch.cuda.graphs.CUDAGraph()
10524*da0073e9SAndroid Build Coastguard Worker
10525*da0073e9SAndroid Build Coastguard Worker    def test_tensor_where_scalar(self):
10526*da0073e9SAndroid Build Coastguard Worker
10527*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(4.0)
10528*da0073e9SAndroid Build Coastguard Worker        not_zero = 0.001
10529*da0073e9SAndroid Build Coastguard Worker
10530*da0073e9SAndroid Build Coastguard Worker        # b is generated through the torch.where function with not_zero as a scalar parameter
10531*da0073e9SAndroid Build Coastguard Worker        b = torch.where(a != 0, a, not_zero)
10532*da0073e9SAndroid Build Coastguard Worker        # c is generated through the Tensor.where method with not_zero as a scalar parameter
10533*da0073e9SAndroid Build Coastguard Worker        c = a.where(a != 0, not_zero)
10534*da0073e9SAndroid Build Coastguard Worker
10535*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b, c)
10536*da0073e9SAndroid Build Coastguard Worker
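    # Tiny worked sketch (illustrative only): with a = [0., 1., 2., 3.] and
    # not_zero = 0.001, both the functional and the method form replace only
    # the zero entry.
    def _where_scalar_sketch(self):
        a = torch.arange(4.0)
        expected = torch.tensor([0.001, 1.0, 2.0, 3.0])
        self.assertEqual(torch.where(a != 0, a, 0.001), expected)
        self.assertEqual(a.where(a != 0, 0.001), expected)
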
10537*da0073e9SAndroid Build Coastguard Worker    def test_data_ptr_of_empty_tensor_with_storage(self):
10538*da0073e9SAndroid Build Coastguard Worker        t = torch.empty((2, 2))
10539*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(t.data_ptr(), 0)
10540*da0073e9SAndroid Build Coastguard Worker        t.resize_((0, 2))
10541*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.data_ptr(), 0)
10542*da0073e9SAndroid Build Coastguard Worker
10543*da0073e9SAndroid Build Coastguard Worker    def test_data_ptr_of_empty_view_with_storage(self):
10544*da0073e9SAndroid Build Coastguard Worker        t = torch.empty((2, 2))
10545*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(t.data_ptr(), 0)
10546*da0073e9SAndroid Build Coastguard Worker        t2 = t[0:0].view(0, 1)
10547*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t2.data_ptr(), 0)
10548*da0073e9SAndroid Build Coastguard Worker
10549*da0073e9SAndroid Build Coastguard Worker    def test_size_stride(self) -> None:
10550*da0073e9SAndroid Build Coastguard Worker        t = torch.rand(2, 3, dtype=torch.float32)
10551*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.size(0), 2)
10552*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.size(dim=None), torch.Size([2, 3]))
10553*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.stride(dim=None), torch.Size([3, 1]))
10554*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.t().stride(), torch.Size([1, 3]))
10555*da0073e9SAndroid Build Coastguard Worker
10556*da0073e9SAndroid Build Coastguard Worker    def test_invalid_arg_error_handling(self) -> None:
10557*da0073e9SAndroid Build Coastguard Worker        """ Tests that errors from old TH functions are propagated back """
10558*da0073e9SAndroid Build Coastguard Worker        for invalid_val in [-1, 2**65]:
10559*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val))
10560*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val))
10561*da0073e9SAndroid Build Coastguard Worker
10562*da0073e9SAndroid Build Coastguard Worker    def _get_tensor_prop(self, t):
10563*da0073e9SAndroid Build Coastguard Worker        preserved = (
10564*da0073e9SAndroid Build Coastguard Worker            id(t),
10565*da0073e9SAndroid Build Coastguard Worker            # Refcount values get modified by Dynamo resume frames
10566*da0073e9SAndroid Build Coastguard Worker            0 if TEST_WITH_TORCHDYNAMO else sys.getrefcount(t),
10567*da0073e9SAndroid Build Coastguard Worker        )
10568*da0073e9SAndroid Build Coastguard Worker        slotnames = copyreg._slotnames(t.__class__)
10569*da0073e9SAndroid Build Coastguard Worker        moved = (
10570*da0073e9SAndroid Build Coastguard Worker            slotnames,
10571*da0073e9SAndroid Build Coastguard Worker            id(t.__dict__),
10572*da0073e9SAndroid Build Coastguard Worker            tuple(t.__dict__.keys()),
10573*da0073e9SAndroid Build Coastguard Worker            [getattr(t, name, None) for name in slotnames]
10574*da0073e9SAndroid Build Coastguard Worker        )
10575*da0073e9SAndroid Build Coastguard Worker        return preserved, moved
10576*da0073e9SAndroid Build Coastguard Worker
10577*da0073e9SAndroid Build Coastguard Worker    def _checked_swap(self, t1, t2):
10578*da0073e9SAndroid Build Coastguard Worker        t1_pres, t1_moved = self._get_tensor_prop(t1)
10579*da0073e9SAndroid Build Coastguard Worker        t2_pres, t2_moved = self._get_tensor_prop(t2)
10580*da0073e9SAndroid Build Coastguard Worker
10581*da0073e9SAndroid Build Coastguard Worker        torch.utils.swap_tensors(t1, t2)
10582*da0073e9SAndroid Build Coastguard Worker
10583*da0073e9SAndroid Build Coastguard Worker        new_t1_pres, new_t1_moved = self._get_tensor_prop(t1)
10584*da0073e9SAndroid Build Coastguard Worker        new_t2_pres, new_t2_moved = self._get_tensor_prop(t2)
10585*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1_pres, new_t1_pres)
10586*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t2_pres, new_t2_pres)
10587*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1_moved, new_t2_moved)
10588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t2_moved, new_t1_moved)
10589*da0073e9SAndroid Build Coastguard Worker
10590*da0073e9SAndroid Build Coastguard Worker        # tests that PyObject slots on TensorImpl are correctly swapped by
10591*da0073e9SAndroid Build Coastguard Worker        # checking that when a function applied to a swapped tensor returns
10592*da0073e9SAndroid Build Coastguard Worker        # without changing the TensorImpl, the returned value (which is
10593*da0073e9SAndroid Build Coastguard Worker        # given by the reference to the PyObject stored in the TensorImpl's
10594*da0073e9SAndroid Build Coastguard Worker        # PyObjectSlot) is still the same Python object
10595*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(id(t1.fill_(0.5)), id(t1))
10596*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(id(t2.fill_(0.5)), id(t2))
10597*da0073e9SAndroid Build Coastguard Worker
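    # Hedged sketch (illustrative only, not run by the suite): at a high level,
    # torch.utils.swap_tensors makes each Python object front the other's
    # underlying tensor while the object identities themselves are preserved,
    # which is what _checked_swap verifies in detail above.
    def _swap_tensors_sketch(self):
        t1, t2 = torch.zeros(2), torch.ones(3)
        id1, id2 = id(t1), id(t2)
        torch.utils.swap_tensors(t1, t2)
        self.assertEqual(t1.shape, torch.Size([3]))      # t1 now holds t2's data
        self.assertEqual(t2.shape, torch.Size([2]))
        self.assertEqual((id(t1), id(t2)), (id1, id2))   # identities unchanged
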
10598*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo adds weakrefs")
10599*da0073e9SAndroid Build Coastguard Worker    def test_swap_basic(self):
10600*da0073e9SAndroid Build Coastguard Worker        ts = [
10601*da0073e9SAndroid Build Coastguard Worker            torch.rand(2),
10602*da0073e9SAndroid Build Coastguard Worker            torch.rand(3, 3),
10603*da0073e9SAndroid Build Coastguard Worker            torch.empty(3, dtype=torch.int),
10604*da0073e9SAndroid Build Coastguard Worker            TwoTensor(torch.rand(4), torch.rand(4))
10605*da0073e9SAndroid Build Coastguard Worker        ]
10606*da0073e9SAndroid Build Coastguard Worker
10607*da0073e9SAndroid Build Coastguard Worker        for t1, t2 in itertools.combinations(ts, 2):
10608*da0073e9SAndroid Build Coastguard Worker            t1 = t1.clone()
10609*da0073e9SAndroid Build Coastguard Worker            t2 = t2.clone()
10610*da0073e9SAndroid Build Coastguard Worker            t2.foo = "bar"
10611*da0073e9SAndroid Build Coastguard Worker            holder = []
10612*da0073e9SAndroid Build Coastguard Worker            holder.append(t1)
10613*da0073e9SAndroid Build Coastguard Worker
10614*da0073e9SAndroid Build Coastguard Worker            self._checked_swap(t1, t2)
10615*da0073e9SAndroid Build Coastguard Worker
10616*da0073e9SAndroid Build Coastguard Worker            self.assertIs(holder[0], t1)
10617*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1.foo, "bar")
10618*da0073e9SAndroid Build Coastguard Worker
10619*da0073e9SAndroid Build Coastguard Worker            if t1.is_floating_point():
10620*da0073e9SAndroid Build Coastguard Worker                t3 = t1.clone().detach().requires_grad_(True)
10621*da0073e9SAndroid Build Coastguard Worker                out = t3 * 2
10622*da0073e9SAndroid Build Coastguard Worker                torch.utils.swap_tensors(t3, t2)
10623*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "AccumulateGrad node that was poisoned by swap_tensors"):
10624*da0073e9SAndroid Build Coastguard Worker                    out.sum().backward()
10625*da0073e9SAndroid Build Coastguard Worker
10626*da0073e9SAndroid Build Coastguard Worker            wr = weakref.ref(t1)
10627*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "has weakref"):
10628*da0073e9SAndroid Build Coastguard Worker                torch.utils.swap_tensors(t1, t2)
10629*da0073e9SAndroid Build Coastguard Worker
10630*da0073e9SAndroid Build Coastguard Worker
10631*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo adds weakrefs")
10632*da0073e9SAndroid Build Coastguard Worker    def test_swap_fail_slots(self):
10633*da0073e9SAndroid Build Coastguard Worker        class MyTwoTensor(TwoTensor):
10634*da0073e9SAndroid Build Coastguard Worker            __slots__ = ("a", "b")
10635*da0073e9SAndroid Build Coastguard Worker
10636*da0073e9SAndroid Build Coastguard Worker        class MyTwoTensor2(TwoTensor):
10637*da0073e9SAndroid Build Coastguard Worker            __slots__ = ("b", "a")
10638*da0073e9SAndroid Build Coastguard Worker
10639*da0073e9SAndroid Build Coastguard Worker        class MyTwoTensor3(TwoTensor):
10640*da0073e9SAndroid Build Coastguard Worker            __slots__ = ("a", "b", "c", "d")
10641*da0073e9SAndroid Build Coastguard Worker
10642*da0073e9SAndroid Build Coastguard Worker        class MyTwoTensor4(TwoTensor):
10643*da0073e9SAndroid Build Coastguard Worker            __slots__ = ("a", "c")
10644*da0073e9SAndroid Build Coastguard Worker
10645*da0073e9SAndroid Build Coastguard Worker
10646*da0073e9SAndroid Build Coastguard Worker        t1 = torch.rand(4)
10647*da0073e9SAndroid Build Coastguard Worker        t2 = TwoTensor(torch.rand(4), torch.rand(4))
10648*da0073e9SAndroid Build Coastguard Worker        t3 = MyTwoTensor(torch.rand(4), torch.rand(4))
10649*da0073e9SAndroid Build Coastguard Worker        t4 = MyTwoTensor(torch.rand(4), torch.rand(4))
10650*da0073e9SAndroid Build Coastguard Worker        t5 = MyTwoTensor2(torch.rand(4), torch.rand(4))
10651*da0073e9SAndroid Build Coastguard Worker        t6 = MyTwoTensor3(torch.rand(4), torch.rand(4))
10652*da0073e9SAndroid Build Coastguard Worker        t7 = MyTwoTensor3(torch.rand(4), torch.rand(4))
10653*da0073e9SAndroid Build Coastguard Worker        t8 = MyTwoTensor4(torch.rand(4), torch.rand(4))
10654*da0073e9SAndroid Build Coastguard Worker
10655*da0073e9SAndroid Build Coastguard Worker        self._checked_swap(t1, t2)
10656*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"):
10657*da0073e9SAndroid Build Coastguard Worker            torch.utils.swap_tensors(t1, t3)
10658*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"):
10659*da0073e9SAndroid Build Coastguard Worker            torch.utils.swap_tensors(t2, t3)
10660*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"):
10661*da0073e9SAndroid Build Coastguard Worker            torch.utils.swap_tensors(t2, t8)
10662*da0073e9SAndroid Build Coastguard Worker        self._checked_swap(t3, t4)
10663*da0073e9SAndroid Build Coastguard Worker        self._checked_swap(t3, t5)
10664*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"):
10665*da0073e9SAndroid Build Coastguard Worker            torch.utils.swap_tensors(t3, t6)
10666*da0073e9SAndroid Build Coastguard Worker        t3.c = "foo"
10667*da0073e9SAndroid Build Coastguard Worker        t4.d = "bar"
10668*da0073e9SAndroid Build Coastguard Worker        self._checked_swap(t3, t4)
10669*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t4.c, "foo")
10670*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t3.d, "bar")
10671*da0073e9SAndroid Build Coastguard Worker        t6.c = "cat"
10672*da0073e9SAndroid Build Coastguard Worker        t7.d = "dog"
10673*da0073e9SAndroid Build Coastguard Worker        self._checked_swap(t6, t7)
10674*da0073e9SAndroid Build Coastguard Worker
10675*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(torch.cuda.is_available(), "Test specific for CPU")
10676*da0073e9SAndroid Build Coastguard Worker    def test_bf16_supported_on_cpu(self):
10677*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.cuda.is_bf16_supported())
10678*da0073e9SAndroid Build Coastguard Worker
10679*da0073e9SAndroid Build Coastguard Worker
10680*da0073e9SAndroid Build Coastguard Worker# The following block extends TestTorch with negative dim wrapping tests
10681*da0073e9SAndroid Build Coastguard Worker# FIXME: replace these with OpInfo sample inputs or systematic OpInfo tests
10682*da0073e9SAndroid Build Coastguard Worker# Functions to test negative dimension wrapping
10683*da0073e9SAndroid Build Coastguard WorkerMETHOD = 1
10684*da0073e9SAndroid Build Coastguard WorkerINPLACE_METHOD = 2
10685*da0073e9SAndroid Build Coastguard WorkerFUNCTIONAL = 4
10686*da0073e9SAndroid Build Coastguard WorkerDIM_ARG: None = None
10687*da0073e9SAndroid Build Coastguard Worker
10688*da0073e9SAndroid Build Coastguard Workerdef make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0):
10689*da0073e9SAndroid Build Coastguard Worker    def neg_dim_test(self):
10690*da0073e9SAndroid Build Coastguard Worker        if isinstance(tensor_arg, list):
10691*da0073e9SAndroid Build Coastguard Worker            assert METHOD not in types and INPLACE_METHOD not in types
10692*da0073e9SAndroid Build Coastguard Worker            x = [torch.randn(arg) for arg in tensor_arg]
10693*da0073e9SAndroid Build Coastguard Worker            ndim = len(tensor_arg[-1])
10694*da0073e9SAndroid Build Coastguard Worker        else:
10695*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(*tensor_arg)
10696*da0073e9SAndroid Build Coastguard Worker            ndim = len(tensor_arg)
10697*da0073e9SAndroid Build Coastguard Worker        ndim += extra_dim
10698*da0073e9SAndroid Build Coastguard Worker
10699*da0073e9SAndroid Build Coastguard Worker        n_dim_to_test = sum(e is DIM_ARG for e in arg_constr())
10700*da0073e9SAndroid Build Coastguard Worker
10701*da0073e9SAndroid Build Coastguard Worker        for dims_val in combinations(range(ndim), n_dim_to_test):
10702*da0073e9SAndroid Build Coastguard Worker            arg = arg_constr()
10703*da0073e9SAndroid Build Coastguard Worker            arg_neg = copy.deepcopy(arg)
10704*da0073e9SAndroid Build Coastguard Worker            idx = 0
10705*da0073e9SAndroid Build Coastguard Worker            for i, v in enumerate(arg):
10706*da0073e9SAndroid Build Coastguard Worker                if v is DIM_ARG:
10707*da0073e9SAndroid Build Coastguard Worker                    arg[i] = dims_val[idx]
10708*da0073e9SAndroid Build Coastguard Worker                    arg_neg[i] = dims_val[idx] - ndim
10709*da0073e9SAndroid Build Coastguard Worker                    idx += 1
10710*da0073e9SAndroid Build Coastguard Worker
10711*da0073e9SAndroid Build Coastguard Worker            if METHOD in types:
10712*da0073e9SAndroid Build Coastguard Worker                a = getattr(x, name)(*arg)
10713*da0073e9SAndroid Build Coastguard Worker                b = getattr(x, name)(*arg_neg)
10714*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, b)
10715*da0073e9SAndroid Build Coastguard Worker
10716*da0073e9SAndroid Build Coastguard Worker            if INPLACE_METHOD in types:
10717*da0073e9SAndroid Build Coastguard Worker                a = x.clone()
10718*da0073e9SAndroid Build Coastguard Worker                getattr(a, name + '_')(*arg)
10719*da0073e9SAndroid Build Coastguard Worker                b = x.clone()
10720*da0073e9SAndroid Build Coastguard Worker                getattr(b, name + '_')(*arg_neg)
10721*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, b)
10722*da0073e9SAndroid Build Coastguard Worker
10723*da0073e9SAndroid Build Coastguard Worker            if FUNCTIONAL in types:
10724*da0073e9SAndroid Build Coastguard Worker                a = getattr(torch, name)(x, *arg)
10725*da0073e9SAndroid Build Coastguard Worker                b = getattr(torch, name)(x, *arg_neg)
10726*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, b)
10727*da0073e9SAndroid Build Coastguard Worker
10728*da0073e9SAndroid Build Coastguard Worker    return neg_dim_test
10729*da0073e9SAndroid Build Coastguard Worker
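# Hedged sketch (illustrative only, not one of the generated tests): every
# generated *_neg_dim test boils down to checking that a negative dim index is
# wrapped to dim + ndim, so e.g. dim=-1 must behave exactly like dim=ndim-1.
def _neg_dim_wrapping_example():
    x = torch.randn(10, 20)
    assert torch.equal(x.sum(dim=-1), x.sum(dim=1))
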
10730*da0073e9SAndroid Build Coastguard Workerdef idx_tensor(size, max_val):
10731*da0073e9SAndroid Build Coastguard Worker    return torch.LongTensor(*size).random_(0, max_val - 1)
10732*da0073e9SAndroid Build Coastguard Worker
10733*da0073e9SAndroid Build Coastguard Workerdef add_neg_dim_tests():
10734*da0073e9SAndroid Build Coastguard Worker    neg_dim_tests = [
10735*da0073e9SAndroid Build Coastguard Worker        ('narrow', (10, 20, 30), lambda: [DIM_ARG, 0, 5], [METHOD]),
10736*da0073e9SAndroid Build Coastguard Worker        ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10737*da0073e9SAndroid Build Coastguard Worker        ('size', (10, 20, 30), lambda: [DIM_ARG], [METHOD]),
10738*da0073e9SAndroid Build Coastguard Worker        ('cat', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]),
10739*da0073e9SAndroid Build Coastguard Worker        ('chunk', (10, 20, 30), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10740*da0073e9SAndroid Build Coastguard Worker        ('gather', (10, 20), lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]),
10741*da0073e9SAndroid Build Coastguard Worker        ('index_select', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]),
10742*da0073e9SAndroid Build Coastguard Worker        ('split', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10743*da0073e9SAndroid Build Coastguard Worker        ('squeeze', (10, 1, 20, 1), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10744*da0073e9SAndroid Build Coastguard Worker        ('unbind', (2, 3, 4), lambda: [DIM_ARG], [FUNCTIONAL]),
10745*da0073e9SAndroid Build Coastguard Worker        ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1),
10746*da0073e9SAndroid Build Coastguard Worker        ('logcumsumexp', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10747*da0073e9SAndroid Build Coastguard Worker        ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10748*da0073e9SAndroid Build Coastguard Worker        ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10749*da0073e9SAndroid Build Coastguard Worker        ('cummax', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10750*da0073e9SAndroid Build Coastguard Worker        ('cummin', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10751*da0073e9SAndroid Build Coastguard Worker        ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10752*da0073e9SAndroid Build Coastguard Worker        ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10753*da0073e9SAndroid Build Coastguard Worker        ('nanmedian', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10754*da0073e9SAndroid Build Coastguard Worker        ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10755*da0073e9SAndroid Build Coastguard Worker        ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]),
10756*da0073e9SAndroid Build Coastguard Worker        ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10757*da0073e9SAndroid Build Coastguard Worker        ('std', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10758*da0073e9SAndroid Build Coastguard Worker        ('sum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10759*da0073e9SAndroid Build Coastguard Worker        ('var', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10760*da0073e9SAndroid Build Coastguard Worker        ('kthvalue', (10, 20), lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]),
10761*da0073e9SAndroid Build Coastguard Worker        ('max', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10762*da0073e9SAndroid Build Coastguard Worker        ('min', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10763*da0073e9SAndroid Build Coastguard Worker        ('sort', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10764*da0073e9SAndroid Build Coastguard Worker        ('topk', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10765*da0073e9SAndroid Build Coastguard Worker        ('renorm', (10, 20), lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10766*da0073e9SAndroid Build Coastguard Worker        ('index_add', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10767*da0073e9SAndroid Build Coastguard Worker        ('index_copy', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10768*da0073e9SAndroid Build Coastguard Worker        ('index_fill', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]),
10769*da0073e9SAndroid Build Coastguard Worker        ('scatter', (10, 10), lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10770*da0073e9SAndroid Build Coastguard Worker        ('select', (10, 20), lambda: [DIM_ARG, 3], [METHOD]),
10771*da0073e9SAndroid Build Coastguard Worker        ('unfold', (10, 20), lambda: [DIM_ARG, 5, 2], [METHOD]),
10772*da0073e9SAndroid Build Coastguard Worker    ]
10773*da0073e9SAndroid Build Coastguard Worker
10774*da0073e9SAndroid Build Coastguard Worker    for decl in neg_dim_tests:
10775*da0073e9SAndroid Build Coastguard Worker        if len(decl) == 4:
10776*da0073e9SAndroid Build Coastguard Worker            name, tensor_arg, arg_constr, types = decl
10777*da0073e9SAndroid Build Coastguard Worker            extra_dim = 0
10778*da0073e9SAndroid Build Coastguard Worker        elif len(decl) == 5:
10779*da0073e9SAndroid Build Coastguard Worker            name, tensor_arg, arg_constr, types, extra_dim = decl
10780*da0073e9SAndroid Build Coastguard Worker
10781*da0073e9SAndroid Build Coastguard Worker        test_name = 'test_' + name + '_neg_dim'
10782*da0073e9SAndroid Build Coastguard Worker
10783*da0073e9SAndroid Build Coastguard Worker        assert not hasattr(TestTorch, test_name), "Duplicated test name: " + test_name
10784*da0073e9SAndroid Build Coastguard Worker        setattr(TestTorch, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim))
10785*da0073e9SAndroid Build Coastguard Worker
10786*da0073e9SAndroid Build Coastguard Worker# TODO: these empty classes are temporarily instantiated for XLA compatibility
10787*da0073e9SAndroid Build Coastguard Worker#   once XLA updates its test suite they should be removed
10788*da0073e9SAndroid Build Coastguard Workerclass TestViewOps(TestCase):
10789*da0073e9SAndroid Build Coastguard Worker    pass
10790*da0073e9SAndroid Build Coastguard Worker
10791*da0073e9SAndroid Build Coastguard Workerclass TestTensorDeviceOps(TestCase):
10792*da0073e9SAndroid Build Coastguard Worker    pass
10793*da0073e9SAndroid Build Coastguard Worker
10794*da0073e9SAndroid Build Coastguard Worker# Generates tests
10795*da0073e9SAndroid Build Coastguard Worker# Note: test generation must be done at file scope, not within main, or
10796*da0073e9SAndroid Build Coastguard Worker# pytest will fail.
10797*da0073e9SAndroid Build Coastguard Workeradd_neg_dim_tests()
10798*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestViewOps, globals())
10799*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestVitalSignsCuda, globals())
10800*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTensorDeviceOps, globals())
10801*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTorchDeviceType, globals())
10802*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu')
10803*da0073e9SAndroid Build Coastguard Worker
10804*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
10805*da0073e9SAndroid Build Coastguard Worker    TestCase._default_dtype_check_enabled = True
10806*da0073e9SAndroid Build Coastguard Worker    run_tests()
10807