# Owner(s): ["module: cuda"]

import contextlib
import ctypes
import gc
import json
import os
import pickle
import random
import subprocess
import sys
import tempfile
import threading
import unittest
import warnings
from copy import deepcopy
from itertools import product
from random import randint

import psutil

import torch
import torch.cuda
from torch import inf, nan
from torch.cuda._memory_viz import (
    _profile_to_snapshot,
    profile_plot,
    segment_plot,
    trace_plot,
)
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
from torch.testing._internal.common_cuda import (
    _create_scaling_case,
    _get_torch_cuda_version,
    TEST_CUDNN,
    TEST_MULTIGPU,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    optims,
    TensorTracker,
)
from torch.testing._internal.common_utils import (
    EXPANDABLE_SEGMENTS,
    freeze_rng_state,
    gcIfJetson,
    get_cycles_per_ms,
    instantiate_parametrized_tests,
    IS_ARM64,
    IS_FBCODE,
    IS_JETSON,
    IS_LINUX,
    IS_SANDCASTLE,
    IS_WINDOWS,
    load_tests,
    NO_MULTIPROCESSING_SPAWN,
    parametrize,
    run_tests,
    serialTest,
    skipCUDAMemoryLeakCheckIf,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
    slowTest,
    subtest,
    TemporaryFileName,
    TEST_CUDA,
    TEST_CUDA_GRAPH,
    TEST_NUMPY,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.utils.checkpoint import checkpoint_sequential
from torch.utils.viz._cycles import observe_tensor_cycles


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

try:
    import torchvision.models  # noqa: F401
    from torchvision.models import resnet18  # noqa: F401

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

TEST_CUDAMALLOCASYNC = TEST_CUDA and (
    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
)
TEST_LARGE_TENSOR = TEST_CUDA
TEST_MEDIUM_TENSOR = TEST_CUDA
TEST_BF16 = False
TEST_PYNVML = not torch.cuda._HAS_PYNVML
if TEST_CUDA:
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
    TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9
    TEST_BF16 = torch.cuda.is_bf16_supported()

_cycles_per_ms = None


@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCuda(TestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    FIFTY_MIL_CYCLES = 50000000

    def setUp(self):
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @property
    def expandable_segments(self):
        return EXPANDABLE_SEGMENTS

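    # Exercises the host allocator's cudaHostRegister-based pinning path
    # (pinned_use_cuda_host_register with multiple registration threads) and
    # checks that pin_memory() still yields pinned tensors.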
    def test_pinned_memory_with_cudaregister(self):
        torch.cuda.memory._set_allocator_settings(
            "pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
        )
        t = torch.ones(20)
        self.assertFalse(t.is_pinned())
        try:
            pinned_t = torch.ones(1 << 21).pin_memory()
            self.assertTrue(pinned_t.is_pinned())
            pinned_t = torch.ones(1 << 24).pin_memory()
            self.assertTrue(pinned_t.is_pinned())
        except RuntimeError:
            # Some GPUs don't support the same address space on the host and device side.
            pass

    def test_pinned_memory_with_cudaregister_multithread(self):
        num_threads = 4
        threads = [
            threading.Thread(target=self.test_pinned_memory_with_cudaregister)
            for t in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def test_pinned_memory_empty_cache(self):
        for alloc_settings in (True, False):
            torch.cuda.memory._set_allocator_settings(
                f"pinned_use_cuda_host_register:{alloc_settings}"
            )
            try:
                t = torch.ones(1024 * 1024, pin_memory=True)
                self.assertTrue(t.is_pinned())
                del t
                torch._C._host_emptyCache()
            except RuntimeError:
                # Some GPUs don't support the same address space on the host and device side.
                pass

    def test_cudart_register(self):
        t = torch.ones(20)
        self.assertFalse(t.is_pinned())
        cudart = torch.cuda.cudart()
        r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
        self.assertEqual(r, 0)
        self.assertTrue(t.is_pinned())
        r = cudart.cudaHostUnregister(t.data_ptr())
        self.assertEqual(r, 0)
        self.assertFalse(t.is_pinned())

    def test_memory_allocation(self):
        gc.collect()
        torch.cuda.empty_cache()
        mem = None
        size = 1
        prev = 0
        try:
            prev = torch.cuda.memory_allocated()
            mem = torch.cuda.caching_allocator_alloc(size)
            self.assertGreater(torch.cuda.memory_allocated(), prev)
        finally:
            if mem is not None:
                torch.cuda.caching_allocator_delete(mem)
                self.assertEqual(torch.cuda.memory_allocated(), prev)

    def test_check_error(self):
        # Assert this call doesn't raise.
        torch.cuda.check_error(0)

        with self.assertRaisesRegex(
            torch.cuda.CudaError, "out of memory|hipErrorOutOfMemory"
        ):
            torch.cuda.check_error(2)

    def test_cuda_get_device_name(self):
        # Testing the behaviour with None as an argument
        current_device = torch.cuda.current_device()
        current_device_name = torch.cuda.get_device_name(current_device)
        device_name_None = torch.cuda.get_device_name(None)
        self.assertEqual(current_device_name, device_name_None)

        # Testing the behaviour with no argument
        device_name_no_argument = torch.cuda.get_device_name()
        self.assertEqual(current_device_name, device_name_no_argument)

    def test_cuda_get_device_capability(self):
        # Testing the behaviour with None as an argument
        current_device = torch.cuda.current_device()
        current_device_capability = torch.cuda.get_device_capability(current_device)
        device_capability_None = torch.cuda.get_device_capability(None)
        self.assertEqual(current_device_capability, device_capability_None)

        # Testing the behaviour with no argument
        device_capability_no_argument = torch.cuda.get_device_capability()
        self.assertEqual(current_device_capability, device_capability_no_argument)

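    # Requests absurdly large allocations and checks that the caching allocator
    # raises the expected OOM errors without corrupting later work on the device.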
    def test_out_of_memory(self):
        tensor = torch.zeros(1024, device="cuda")

        oom_regex = (
            "would exceed allowed memory"
            if TEST_CUDAMALLOCASYNC
            else "Tried to allocate 800000000.00 GiB"
        )
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            torch.empty(
                1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="cuda"
            )

        with self.assertRaisesRegex(
            RuntimeError, "Tried to allocate more than 1EB memory"
        ):
            torch.empty(
                1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="cuda"
            )

        # ensure that the out-of-memory error doesn't disturb subsequent kernels
        tensor.fill_(1)
        self.assertTrue((tensor == 1).all())

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)"
    )
    @serialTest()
    def test_out_of_memory_retry(self):
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        oom_regex = (
            "would exceed allowed memory"
            if TEST_CUDAMALLOCASYNC
            else "Tried to allocate"
        )
        size = int(total_memory * 0.5)
        a = torch.empty(size, dtype=torch.int8, device="cuda")
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            b = torch.empty(size, dtype=torch.int8, device="cuda")
        del a
        b = torch.empty(size, dtype=torch.int8, device="cuda")
        del b
        # We used a lot of memory here; clean up so we don't affect other tests too much.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    @serialTest()
    def test_set_per_process_memory_fraction(self):
        # test invalid fraction value.
        with self.assertRaisesRegex(TypeError, "Invalid type"):
            torch.cuda.set_per_process_memory_fraction(1)
        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
            torch.cuda.set_per_process_memory_fraction(-0.1)
        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
            torch.cuda.set_per_process_memory_fraction(2.0)

        tensor = torch.zeros(1024, device="cuda")
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        torch.cuda.set_per_process_memory_fraction(0.5, 0)

        # test 0.499 allocation is ok.
        application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved()
        tmp_tensor = torch.empty(application, dtype=torch.int8, device="cuda")
        del tmp_tensor
        torch.cuda.empty_cache()

        application = int(total_memory * 0.5)
        # it will OOM when trying to allocate more than half the memory.
        oom_regex = (
            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
        )
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            torch.empty(application, dtype=torch.int8, device="cuda")

        # ensure that the out-of-memory error doesn't disturb subsequent kernels
        tensor.fill_(1)
        self.assertTrue((tensor == 1).all())

    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "uuid attribute not yet available")
    def test_uuid(self):
        uuid = torch.cuda.get_device_properties(0).uuid
        self.assertEqual(len(str(uuid)), 36)  # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
        self.assertEqual(len(uuid.bytes), 16)

    def test_copy_non_blocking(self):
        def _test_copy_non_blocking(a, b):
            event = torch.cuda.Event()
            a.copy_(b, non_blocking=True)
            event.record()
            event.synchronize()
            self.assertEqual(a, b)

        # 10MB copies
        x = torch.ones(10000000, dtype=torch.uint8).cuda()
        y = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        _test_copy_non_blocking(x, y)

        x = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        y = torch.ones(10000000, dtype=torch.uint8).cuda()
        _test_copy_non_blocking(x, y)

        # Test the case where the pinned data_ptr is not equal to the storage data_ptr.
        x_base = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        x = x_base[1:]
        self.assertTrue(x.is_pinned())
        self.assertTrue(x_base.is_pinned())
        self.assertNotEqual(x_base.data_ptr(), x.data_ptr())
        self.assertEqual(x_base.storage().data_ptr(), x.storage().data_ptr())
        y = torch.ones(10000000 - 1, dtype=torch.uint8).cuda()
        _test_copy_non_blocking(x, y)

    def test_copy_non_blocking_type_conversion(self):
        a = torch.ones(1, device="cuda")
        b = torch.zeros(1, device="cpu", pin_memory=True)
        c = torch.empty(1, device="cuda", dtype=torch.long)
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        b.copy_(a, non_blocking=True)
        c.copy_(b, non_blocking=True)
        self.assertEqual(a, c, exact_dtype=False)

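    # For host<->device copies, non_blocking=True should return before the copy
    # finishes (the stream is still busy when queried), and a non-blocking copy
    # to the CPU should land in pinned memory.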
    @serialTest()
    def test_to_non_blocking(self):
        stream = torch.cuda.current_stream()

        def _test_to_non_blocking(a, non_blocking, dst):
            torch.cuda.synchronize()
            # Pushes a 0.1 second spin to the stream so that, if the copy is non-blocking,
            # the stream will almost surely be active when we query().
            torch.cuda._sleep(int(100 * get_cycles_per_ms()))
            b = a.to(device=dst, non_blocking=non_blocking)
            self.assertEqual(stream.query(), not non_blocking)
            stream.synchronize()
            self.assertEqual(a, b)
            self.assertTrue(b.is_pinned() == (non_blocking and dst == "cpu"))

        for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)):
            # Creates source on the opposite device from destination.
            src = torch.randn(
                1000000,
                device="cuda" if dst == "cpu" else "cpu",
                pin_memory=True if dst == "cuda" else False,
            )
            _test_to_non_blocking(src, try_non_blocking, dst)

    def test_to_cpu_blocking_by_default(self):
        src = torch.randn(1000000, device="cuda")
        torch.cuda.synchronize()
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        dst = src.to(device="cpu")
        self.assertEqual(torch.cuda.current_stream().query(), True)
        self.assertEqual(src, dst)
        self.assertFalse(dst.is_pinned())

    def test_serialization_array_with_storage(self):
        x = torch.randn(5, 5).cuda()
        y = torch.IntTensor(2, 5).fill_(0).cuda()
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, atol=0, rtol=0)
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
        self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor))
        self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
        self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
        q_copy[1].fill_(10)
        self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
    )
    @unittest.skipIf(
        _get_torch_cuda_version() >= (12, 2),
        "skipped as explicit workspace allocation is removed",
    )
    def test_cublas_workspace_explicit_allocation(self):
        a = torch.randn(7, 7, device="cuda", requires_grad=False)
        default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
        # different size (32 MiB) expected on Hopper GPU
        if torch.cuda.get_device_capability() == (9, 0):
            default_workspace_size = 4096 * 8 * 1024

        def check_workspace_size(inp):
            torch._C._cuda_clearCublasWorkspaces()
            start = torch.cuda.memory_stats()["active_bytes.all.allocated"]
            with torch.no_grad():
                torch.matmul(inp, inp)
            finish = torch.cuda.memory_stats()["active_bytes.all.allocated"]
            return finish - start

        # check default
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
        self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288)

        # check default with bad user config
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = "-1"
        self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288)

        # check valid config
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":128:8:64:16:32:32"
        self.assertTrue(abs(check_workspace_size(a) - (3072 * 1024)) < 524288)

        torch._C._cuda_clearCublasWorkspaces()

    def test_cublas_allow_tf32_get_set(self):
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
        )
        if skip_tf32_cublas:
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
            return

        orig = torch.backends.cuda.matmul.allow_tf32
        self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
        torch.backends.cuda.matmul.allow_tf32 = not orig
        self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
        torch.backends.cuda.matmul.allow_tf32 = orig

    def test_float32_matmul_precision_get_set(self):
        orig = torch.get_float32_matmul_precision()
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
        )
        # this is really just checking that the environment variable is respected during testing
        # and not overwritten by another function that doesn't revert it to the initial value
        if not skip_tf32_cublas:
            self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
            self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        else:
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
        for p in ("medium", "high"):
            torch.set_float32_matmul_precision(p)
            self.assertEqual(torch.get_float32_matmul_precision(), p)
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
        torch.set_float32_matmul_precision("highest")
        self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
        torch.set_float32_matmul_precision(orig)

    def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        self.assertEqual(
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig
        )
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
        self.assertEqual(
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig
        )
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig

    def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
        orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        self.assertEqual(
            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig
        )
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
        self.assertEqual(
            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig
        )
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig

    def test_cudnn_allow_tf32_get_set(self):
        with torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=False
        ):
            self.assertFalse(torch.backends.cudnn.allow_tf32)
        with torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
        ):
            self.assertTrue(torch.backends.cudnn.allow_tf32)

    def test_type_conversions(self):
        x = torch.randn(5, 5)
        self.assertIsInstance(x.float(), torch.FloatTensor)
        self.assertIsInstance(x.cuda().double(), torch.cuda.DoubleTensor)
        self.assertIsInstance(x.cuda().float(), torch.cuda.FloatTensor)
        self.assertIsInstance(x.cuda().float().cpu(), torch.FloatTensor)
        self.assertIsInstance(x.cuda().float().cpu().int(), torch.IntTensor)

        y = x.storage()
        self.assertIsInstance(y.float(), torch.FloatStorage)
        self.assertIsInstance(y.cuda().double(), torch.cuda.DoubleStorage)
        self.assertIsInstance(y.cuda().float(), torch.cuda.FloatStorage)
        self.assertIsInstance(y.cuda().float().cpu(), torch.FloatStorage)
        self.assertIsInstance(y.cuda().float().cpu().int(), torch.IntStorage)

    @unittest.skip("was disabled due to not enough memory, but actually it always fails")
    def test_arithmetic_large_tensor(self):
        x = torch.empty(2**30, device="cuda")

        x.fill_(1)
        self.assertEqual(x.sum(), 2**30)

        x += 1
        self.assertEqual(x.sum(), 2**31)

        x.fill_(1)
        x -= 0.5
        self.assertEqual(x.sum(), 2**29)

        x.fill_(1)
        x *= 2
        self.assertEqual(x.sum(), 2**31)

        x.fill_(1)
        x /= 2
        self.assertEqual(x.sum(), 2**29)

    def test_gather_bool(self):
        t = torch.tensor([[False, True], [True, True]], device="cuda")
        self.assertEqual(
            torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]], device="cuda")),
            torch.tensor([[False, False], [True, True]], device="cuda"),
        )

    def test_torch_manual_seed_seeds_cuda_devices(self):
        with freeze_rng_state():
            x = torch.zeros(4, 4).float().cuda()
            torch.manual_seed(2)
            self.assertEqual(torch.cuda.initial_seed(), 2)
            x.uniform_()
            torch.manual_seed(2)
            y = x.clone().uniform_()
            self.assertEqual(x, y)
            self.assertEqual(torch.cuda.initial_seed(), 2)

    def test_manual_seed(self):
        with freeze_rng_state():
            x = torch.zeros(4, 4).float().cuda()
            torch.cuda.manual_seed(2)
            self.assertEqual(torch.cuda.initial_seed(), 2)
            x.uniform_()
            a = torch.bernoulli(torch.full_like(x, 0.5))
            torch.cuda.manual_seed(2)
            y = x.clone().uniform_()
            b = torch.bernoulli(torch.full_like(x, 0.5))
            self.assertEqual(x, y)
            self.assertEqual(a, b)
            self.assertEqual(torch.cuda.initial_seed(), 2)

    def test_specify_improper_device_name(self):
        import os

        fname = "tempfile.pt"
        try:
            with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
                torch.save(
                    [torch.nn.Parameter(torch.randn(10, 10))],
                    fname,
                    _use_new_zipfile_serialization=True,
                )
                torch.load(fname, "cuda0")
        finally:
            if os.path.exists(fname):
                os.remove(fname)

    def test_get_device_index(self):
        from torch.cuda._utils import _get_device_index

        with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
            _get_device_index("cuda0", optional=True)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device"):
            cpu_device = torch.device("cpu")
            _get_device_index(cpu_device, optional=True)

    def test_serialization_array_with_empty(self):
        x = [torch.randn(4, 4).cuda(), torch.cuda.FloatTensor()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), original.get_device())

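    # Basic stream semantics: the default stream is current and has cuda_stream == 0,
    # a user-created stream is distinct and becomes current inside torch.cuda.stream(),
    # and work launched on the default stream completes after synchronize().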
    @skipCUDANonDefaultStreamIf(True)
    def test_streams(self):
        default_stream = torch.cuda.current_stream()
        user_stream = torch.cuda.Stream()
        self.assertEqual(torch.cuda.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        self.assertEqual(default_stream.cuda_stream, 0)
        self.assertNotEqual(user_stream.cuda_stream, 0)
        with torch.cuda.stream(user_stream):
            self.assertEqual(torch.cuda.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        tensor1 = torch.ByteTensor(5).pin_memory()
        tensor2 = tensor1.cuda(non_blocking=True) + 1
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    def test_stream_event_repr(self):
        s = torch.cuda.current_stream()
        self.assertTrue("torch.cuda.Stream" in s.__repr__())
        e = torch.cuda.Event()
        self.assertTrue("torch.cuda.Event" in e.__repr__())
        s.record_event(e)
        self.assertTrue("torch.cuda.Event" in e.__repr__())

    def test_events(self):
        stream = torch.cuda.current_stream()
        event = torch.cuda.Event(enable_timing=True)
        self.assertTrue(event.query())
        start_event = torch.cuda.Event(enable_timing=True)
        stream.record_event(start_event)
        torch.cuda._sleep(int(50 * get_cycles_per_ms()))
        stream.record_event(event)
        self.assertFalse(event.query())
        event.synchronize()
        self.assertTrue(event.query())
        self.assertGreater(start_event.elapsed_time(event), 0)

    def test_generic_stream_event(self):
        stream = torch.Stream("cuda")
        self.assertEqual(stream.device_index, torch.cuda.current_device())
        cuda_stream = torch.cuda.Stream(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
        self.assertEqual(stream.stream_id, cuda_stream.stream_id)
        self.assertNotEqual(stream.stream_id, torch.cuda.current_stream().stream_id)

        event1 = torch.Event("cuda", enable_timing=True)
        event2 = torch.Event("cuda", enable_timing=True)
        self.assertEqual(event1.event_id, 0)
        a = torch.randn(1000)
        b = torch.randn(1000)
        with torch.cuda.stream(cuda_stream):
            a_cuda = a.to("cuda", non_blocking=True)
            b_cuda = b.to("cuda", non_blocking=True)
            self.assertEqual(stream.stream_id, torch.cuda.current_stream().stream_id)
        event1.record(stream)
        event1.synchronize()
        self.assertTrue(event1.query())
        c_cuda = a_cuda + b_cuda
        event2.record()
        event2.synchronize()
        self.assertTrue(event2.query())
        self.assertNotEqual(event1.event_id, event2.event_id)
        self.assertEqual(c_cuda.cpu(), a + b)
        self.assertTrue(event1.elapsed_time(event2) > 0)

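    # Tensor.record_stream(s) tells the caching allocator that the tensor is in
    # use on stream s, so its memory must not be handed out again until the work
    # queued on s at record time has completed.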
    def test_record_stream(self):
        cycles_per_ms = get_cycles_per_ms()

        t = torch.FloatTensor([1, 2, 3, 4]).pin_memory()
        result = torch.cuda.FloatTensor(t.size())
        stream = torch.cuda.Stream()
        ptr = [None]

        # Performs the CPU->GPU copy in a background stream
        def perform_copy():
            with torch.cuda.stream(stream):
                tmp = t.cuda(non_blocking=True)
                ptr[0] = tmp.data_ptr()
            torch.cuda.current_stream().wait_stream(stream)
            tmp.record_stream(torch.cuda.current_stream())
            torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the copy
            result.copy_(tmp)

        perform_copy()
        with torch.cuda.stream(stream):
            tmp2 = torch.cuda.FloatTensor(t.size())
            tmp2.zero_()
            self.assertNotEqual(
696*da0073e9SAndroid Build Coastguard Worker                tmp2.data_ptr(), ptr[0], msg="allocation re-used too soon"
697*da0073e9SAndroid Build Coastguard Worker            )
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.tolist(), [1, 2, 3, 4])
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker        if not TEST_CUDAMALLOCASYNC:
702*da0073e9SAndroid Build Coastguard Worker            # In the native allocator, we expect that "tmp"'s side-stream-tagged block will be
703*da0073e9SAndroid Build Coastguard Worker            # reused in that side stream once result.copy_(tmp) on the main stream finishes.
704*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().synchronize()
705*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(stream):
706*da0073e9SAndroid Build Coastguard Worker                tmp3 = torch.cuda.FloatTensor(t.size())
707*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not re-used")
708*da0073e9SAndroid Build Coastguard Worker
709*da0073e9SAndroid Build Coastguard Worker    def test_record_stream_on_shifted_view(self):
710*da0073e9SAndroid Build Coastguard Worker        # See issue #27366
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker        # This test detects unexpected block reallocation. To keep the test reliable,
713*da0073e9SAndroid Build Coastguard Worker        # the stream used to allocate tensors is isolated: the allocator will not
714*da0073e9SAndroid Build Coastguard Worker        # reuse free blocks that were allocated from another stream.
715*da0073e9SAndroid Build Coastguard Worker        stream_alloc = torch.cuda.Stream()
716*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(stream_alloc):
717*da0073e9SAndroid Build Coastguard Worker            base = torch.cuda.FloatTensor([10, 10])
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker        # Record another stream on a shifted view tensor.
720*da0073e9SAndroid Build Coastguard Worker        view = base[5:]
721*da0073e9SAndroid Build Coastguard Worker        assert view.storage_offset() > 0
722*da0073e9SAndroid Build Coastguard Worker
723*da0073e9SAndroid Build Coastguard Worker        stream_record = torch.cuda.Stream()
724*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(stream_record):
725*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(int(50 * get_cycles_per_ms()))
726*da0073e9SAndroid Build Coastguard Worker
727*da0073e9SAndroid Build Coastguard Worker        view.record_stream(stream_record)
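        # Recording on a shifted view must protect the whole underlying block, not just
        # the offset region; the check below verifies the allocator does not recycle it early.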
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker        # Delete those tensors to make the block free soon.
730*da0073e9SAndroid Build Coastguard Worker        data_ptr = base.data_ptr()
731*da0073e9SAndroid Build Coastguard Worker        del base, view
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Worker        # A new tensor should not be allocated to the block above.
734*da0073e9SAndroid Build Coastguard Worker        stream_alloc.synchronize()
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(stream_alloc):
737*da0073e9SAndroid Build Coastguard Worker            try_realloc = torch.cuda.FloatTensor([10, 10])
738*da0073e9SAndroid Build Coastguard Worker
739*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker    def test_noncontiguous_pinned_memory(self):
742*da0073e9SAndroid Build Coastguard Worker        # See issue #3266
743*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 10).view((2, 5))
744*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.t(), x.t().pin_memory())
745*da0073e9SAndroid Build Coastguard Worker
746*da0073e9SAndroid Build Coastguard Worker    def test_caching_pinned_memory(self):
747*da0073e9SAndroid Build Coastguard Worker        cycles_per_ms = get_cycles_per_ms()
748*da0073e9SAndroid Build Coastguard Worker
749*da0073e9SAndroid Build Coastguard Worker        # check that allocations are re-used after deletion
750*da0073e9SAndroid Build Coastguard Worker        t = torch.FloatTensor([1]).pin_memory()
751*da0073e9SAndroid Build Coastguard Worker        ptr = t.data_ptr()
752*da0073e9SAndroid Build Coastguard Worker        del t
753*da0073e9SAndroid Build Coastguard Worker        t = torch.FloatTensor([1]).pin_memory()
754*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.data_ptr(), ptr, msg="allocation not reused")
755*da0073e9SAndroid Build Coastguard Worker
756*da0073e9SAndroid Build Coastguard Worker        # check that the allocation is not re-used if it's in-use by a copy
757*da0073e9SAndroid Build Coastguard Worker        gpu_tensor = torch.cuda.FloatTensor([0])
758*da0073e9SAndroid Build Coastguard Worker        torch.cuda._sleep(int(1000 * cycles_per_ms))  # delay the copy by 1s
759*da0073e9SAndroid Build Coastguard Worker        gpu_tensor.copy_(t, non_blocking=True)
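        # t is still the source of an in-flight async copy, so its pinned block is busy
        # and must not be handed out again by the caching host allocator yet.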
760*da0073e9SAndroid Build Coastguard Worker        del t
761*da0073e9SAndroid Build Coastguard Worker        t = torch.FloatTensor([1]).pin_memory()
762*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon")
763*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(gpu_tensor), [1])
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker    def test_caching_allocator_record_stream_oom(self):
766*da0073e9SAndroid Build Coastguard Worker        """Allocations delayed by a record_stream call should still be freed on
767*da0073e9SAndroid Build Coastguard Worker        an out-of-memory in cuda_malloc_retry. See issue #19219."""
768*da0073e9SAndroid Build Coastguard Worker        stream = torch.cuda.Stream()
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(stream):
771*da0073e9SAndroid Build Coastguard Worker            y = torch.zeros(40 * 1024 * 1024, device="cuda")
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        for _ in range(100):
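            # Each iteration allocates another ~160 MB block whose reuse is deferred by
            # record_stream below, eventually pushing the allocator into its OOM-retry path.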
774*da0073e9SAndroid Build Coastguard Worker            x = torch.empty(40 * 1024 * 1024, device="cuda")
775*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(stream):
776*da0073e9SAndroid Build Coastguard Worker                y += x
777*da0073e9SAndroid Build Coastguard Worker            # delays re-use of `x` until after all operations in `stream`
778*da0073e9SAndroid Build Coastguard Worker            x.record_stream(stream)
779*da0073e9SAndroid Build Coastguard Worker            del x
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker        # we've made a mess by allocating up to the device capacity. free any
782*da0073e9SAndroid Build Coastguard Worker        # cached blocks in case it affects future tests.
783*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker    # Tests for historic illegal memory access, see #17040.
786*da0073e9SAndroid Build Coastguard Worker    def test_reduction_gpu_memory_accessing(self):
787*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(512, 8, dtype=torch.float32, device="cuda")
788*da0073e9SAndroid Build Coastguard Worker        torch.sum(x, 0)
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker    def test_sum_fp16(self):
791*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(10, device="cuda", dtype=torch.float16)
792*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.sum(), 0)
793*da0073e9SAndroid Build Coastguard Worker
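        # 65504 is the largest finite float16 value, so the sum below sits exactly at the
        # dtype's limit; the 65536-element case afterwards only checks the float32-accumulated sum.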
794*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(65504, device="cuda", dtype=torch.float16)
795*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.sum(), 65504)
796*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.sum(dtype=torch.float32), 65504)
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(65536, device="cuda", dtype=torch.float16)
799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.sum(dtype=torch.float32), 65536)
800*da0073e9SAndroid Build Coastguard Worker
801*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(1203611).bernoulli_(0.0005)
802*da0073e9SAndroid Build Coastguard Worker        x = a.to(device="cuda", dtype=torch.float16)
803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.sum().item(), a.sum().item())
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(100, 121, 80).bernoulli_(0.0005)
806*da0073e9SAndroid Build Coastguard Worker        x = a.to(device="cuda", dtype=torch.float16)
807*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.sum((0, 2)).float().cpu(), a.sum((0, 2)))
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker    def test_mean_fp16(self):
810*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(65536, device="cuda", dtype=torch.float16)
811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.mean(), 1)
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(65536, device="cuda", dtype=torch.float16)
814*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.mean(dtype=torch.float32), 1)
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker    def test_prod_large(self):
817*da0073e9SAndroid Build Coastguard Worker        # tests global reduction (should_global_reduce = true) in case of non-zero identity element
818*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(240000, device="cuda", dtype=torch.float32)
819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.prod(), 1)
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker        # test for complex types. Note 240k is divisible by 4
822*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.cfloat, torch.cdouble]:
823*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(240000, device="cuda", dtype=dtype) * (0 + 1j)
824*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.prod(), 1)
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_ext(self):
827*da0073e9SAndroid Build Coastguard Worker        # Test two corner cases from older PyTorch (Issue #4858)
828*da0073e9SAndroid Build Coastguard Worker        freqs = torch.cuda.FloatTensor(
829*da0073e9SAndroid Build Coastguard Worker            [
830*da0073e9SAndroid Build Coastguard Worker                0.0,
831*da0073e9SAndroid Build Coastguard Worker                0.0,
832*da0073e9SAndroid Build Coastguard Worker                0.0,
833*da0073e9SAndroid Build Coastguard Worker                0.0,
834*da0073e9SAndroid Build Coastguard Worker                0.0,
835*da0073e9SAndroid Build Coastguard Worker                0.0,
836*da0073e9SAndroid Build Coastguard Worker                0.0,
837*da0073e9SAndroid Build Coastguard Worker                0.0,
838*da0073e9SAndroid Build Coastguard Worker                0.0,
839*da0073e9SAndroid Build Coastguard Worker                0.03178183361887932,
840*da0073e9SAndroid Build Coastguard Worker                0.027680952101945877,
841*da0073e9SAndroid Build Coastguard Worker                0.033176131546497345,
842*da0073e9SAndroid Build Coastguard Worker                0.046052902936935425,
843*da0073e9SAndroid Build Coastguard Worker                0.07742464542388916,
844*da0073e9SAndroid Build Coastguard Worker                0.11543981730937958,
845*da0073e9SAndroid Build Coastguard Worker                0.14148041605949402,
846*da0073e9SAndroid Build Coastguard Worker                0.15784293413162231,
847*da0073e9SAndroid Build Coastguard Worker                0.13180233538150787,
848*da0073e9SAndroid Build Coastguard Worker                0.08271478116512299,
849*da0073e9SAndroid Build Coastguard Worker                0.049702685326337814,
850*da0073e9SAndroid Build Coastguard Worker                0.027557924389839172,
851*da0073e9SAndroid Build Coastguard Worker                0.018125897273421288,
852*da0073e9SAndroid Build Coastguard Worker                0.011851548217236996,
853*da0073e9SAndroid Build Coastguard Worker                0.010252203792333603,
854*da0073e9SAndroid Build Coastguard Worker                0.007422595750540495,
855*da0073e9SAndroid Build Coastguard Worker                0.005372154992073774,
856*da0073e9SAndroid Build Coastguard Worker                0.0045109698548913,
857*da0073e9SAndroid Build Coastguard Worker                0.0036087757907807827,
858*da0073e9SAndroid Build Coastguard Worker                0.0035267581697553396,
859*da0073e9SAndroid Build Coastguard Worker                0.0018864056328311563,
860*da0073e9SAndroid Build Coastguard Worker                0.0024605290964245796,
861*da0073e9SAndroid Build Coastguard Worker                0.0022964938543736935,
862*da0073e9SAndroid Build Coastguard Worker                0.0018453967059031129,
863*da0073e9SAndroid Build Coastguard Worker                0.0010662291897460818,
864*da0073e9SAndroid Build Coastguard Worker                0.0009842115687206388,
865*da0073e9SAndroid Build Coastguard Worker                0.00045109697384759784,
866*da0073e9SAndroid Build Coastguard Worker                0.0007791675161570311,
867*da0073e9SAndroid Build Coastguard Worker                0.00020504408166743815,
868*da0073e9SAndroid Build Coastguard Worker                0.00020504408166743815,
869*da0073e9SAndroid Build Coastguard Worker                0.00020504408166743815,
870*da0073e9SAndroid Build Coastguard Worker                0.00012302644609007984,
871*da0073e9SAndroid Build Coastguard Worker                0.0,
872*da0073e9SAndroid Build Coastguard Worker                0.00012302644609007984,
873*da0073e9SAndroid Build Coastguard Worker                4.100881778867915e-05,
874*da0073e9SAndroid Build Coastguard Worker                0.0,
875*da0073e9SAndroid Build Coastguard Worker                0.0,
876*da0073e9SAndroid Build Coastguard Worker                0.0,
877*da0073e9SAndroid Build Coastguard Worker                0.0,
878*da0073e9SAndroid Build Coastguard Worker                0.0,
879*da0073e9SAndroid Build Coastguard Worker                0.0,
880*da0073e9SAndroid Build Coastguard Worker            ]
881*da0073e9SAndroid Build Coastguard Worker        )
882*da0073e9SAndroid Build Coastguard Worker
883*da0073e9SAndroid Build Coastguard Worker        torch.cuda.manual_seed(11042)
884*da0073e9SAndroid Build Coastguard Worker        sample = torch.multinomial(freqs, 1000, True)
885*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(freqs[sample].min(), 0)
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker        p = torch.zeros(3421, 2, device="cuda", dtype=torch.float)
888*da0073e9SAndroid Build Coastguard Worker        p[:, 1] = 1
889*da0073e9SAndroid Build Coastguard Worker        torch.cuda.manual_seed(5214)
890*da0073e9SAndroid Build Coastguard Worker        r = torch.multinomial(p, 1)
891*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(r.min().item(), 0)
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker        # test corner case from Issue #13867
894*da0073e9SAndroid Build Coastguard Worker        torch.cuda.manual_seed(33)
895*da0073e9SAndroid Build Coastguard Worker        probs = torch.randn(1000000, device="cuda").clamp(min=0) * 3e-5
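        # The bug in #13867 let multinomial pick zero-probability categories from a very
        # sparse distribution; the assertion below guards against that regression.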
896*da0073e9SAndroid Build Coastguard Worker        samples = probs.multinomial(1000000, replacement=True)
897*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(probs[samples].min().item(), 0)
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker    def _spawn_test_multinomial_invalid_probs_cuda(self, probs):
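        # Invalid probabilities trip a device-side assert, which leaves the CUDA context
        # unusable, so the failing call runs in a throwaway subprocess.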
900*da0073e9SAndroid Build Coastguard Worker        import subprocess
901*da0073e9SAndroid Build Coastguard Worker
902*da0073e9SAndroid Build Coastguard Worker        try:
903*da0073e9SAndroid Build Coastguard Worker            p = subprocess.Popen(
904*da0073e9SAndroid Build Coastguard Worker                [
905*da0073e9SAndroid Build Coastguard Worker                    sys.executable,
906*da0073e9SAndroid Build Coastguard Worker                    "-c",
907*da0073e9SAndroid Build Coastguard Worker                    f"""\
908*da0073e9SAndroid Build Coastguard Workerimport sys
909*da0073e9SAndroid Build Coastguard Workerimport torch
910*da0073e9SAndroid Build Coastguard Workerfrom torch import inf, nan
911*da0073e9SAndroid Build Coastguard Workertry:
912*da0073e9SAndroid Build Coastguard Worker    with torch.random.fork_rng(devices=[0]):
913*da0073e9SAndroid Build Coastguard Worker        torch.multinomial(torch.tensor({probs}).to('cuda'), 2, replacement=True)
914*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
915*da0073e9SAndroid Build Coastguard Worker    sys.exit(-1) # Should not be reached
916*da0073e9SAndroid Build Coastguard Workerexcept RuntimeError as e:
917*da0073e9SAndroid Build Coastguard Worker    sys.exit(-2)
918*da0073e9SAndroid Build Coastguard Worker""",
919*da0073e9SAndroid Build Coastguard Worker                ],
920*da0073e9SAndroid Build Coastguard Worker                stdout=subprocess.PIPE,
921*da0073e9SAndroid Build Coastguard Worker                stderr=subprocess.PIPE,
922*da0073e9SAndroid Build Coastguard Worker                universal_newlines=True,
923*da0073e9SAndroid Build Coastguard Worker            )
924*da0073e9SAndroid Build Coastguard Worker            out, err = p.communicate(timeout=10)
925*da0073e9SAndroid Build Coastguard Worker            p.wait(timeout=10)
926*da0073e9SAndroid Build Coastguard Worker        except subprocess.TimeoutExpired as e:
927*da0073e9SAndroid Build Coastguard Worker            p.kill()
928*da0073e9SAndroid Build Coastguard Worker            out, err = p.communicate()
929*da0073e9SAndroid Build Coastguard Worker        expected_messages = [
930*da0073e9SAndroid Build Coastguard Worker            "device-side assert triggered",  # CUDA
931*da0073e9SAndroid Build Coastguard Worker            "Assertion",  # CUDA
932*da0073e9SAndroid Build Coastguard Worker            "HSA_STATUS_ERROR_EXCEPTION",  # ROCm
933*da0073e9SAndroid Build Coastguard Worker            "Device-side assertion",  # ROCm
934*da0073e9SAndroid Build Coastguard Worker        ]
935*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any(msg in out or msg in err for msg in expected_messages))
936*da0073e9SAndroid Build Coastguard Worker
937*da0073e9SAndroid Build Coastguard Worker    @slowTest
938*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
939*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
940*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
941*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that don't support "
942*da0073e9SAndroid Build Coastguard Worker        "multiprocessing with spawn start method",
943*da0073e9SAndroid Build Coastguard Worker    )
944*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_invalid_probs_cuda(self):
945*da0073e9SAndroid Build Coastguard Worker        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -1.0, 1.0])
946*da0073e9SAndroid Build Coastguard Worker        self._spawn_test_multinomial_invalid_probs_cuda([1.0, inf, 1.0])
947*da0073e9SAndroid Build Coastguard Worker        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -inf, 1.0])
948*da0073e9SAndroid Build Coastguard Worker        self._spawn_test_multinomial_invalid_probs_cuda([1.0, 1.0, nan])
949*da0073e9SAndroid Build Coastguard Worker
950*da0073e9SAndroid Build Coastguard Worker    @staticmethod
951*da0073e9SAndroid Build Coastguard Worker    def _mute_init():
952*da0073e9SAndroid Build Coastguard Worker        os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker    def _spawn_method(self, method, arg):
955*da0073e9SAndroid Build Coastguard Worker        ctx = torch.multiprocessing.get_context("spawn")
956*da0073e9SAndroid Build Coastguard Worker        with ctx.Pool(1, initializer=self._mute_init) as pool:
957*da0073e9SAndroid Build Coastguard Worker            errors = pool.map(method, [arg])
958*da0073e9SAndroid Build Coastguard Worker            for e in errors:
959*da0073e9SAndroid Build Coastguard Worker                if "device-side assert triggered" not in str(e):
960*da0073e9SAndroid Build Coastguard Worker                    self.fail(e)
961*da0073e9SAndroid Build Coastguard Worker
962*da0073e9SAndroid Build Coastguard Worker    @staticmethod
963*da0073e9SAndroid Build Coastguard Worker    def _test_index_bounds_cuda(idx):
964*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(10, device="cuda")
965*da0073e9SAndroid Build Coastguard Worker        try:
966*da0073e9SAndroid Build Coastguard Worker            y = x[torch.tensor([idx])]
967*da0073e9SAndroid Build Coastguard Worker            return f"x[torch.tensor([{idx}])]={y}"
968*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as err:
969*da0073e9SAndroid Build Coastguard Worker            return err
970*da0073e9SAndroid Build Coastguard Worker
971*da0073e9SAndroid Build Coastguard Worker    @slowTest
972*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
973*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
974*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that don't support "
975*da0073e9SAndroid Build Coastguard Worker        "multiprocessing with spawn start method",
976*da0073e9SAndroid Build Coastguard Worker    )
977*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
978*da0073e9SAndroid Build Coastguard Worker    def test_index_out_of_bounds_exception_cuda(self):
979*da0073e9SAndroid Build Coastguard Worker        test_method = TestCuda._test_index_bounds_cuda
980*da0073e9SAndroid Build Coastguard Worker        # Test in-bound access works fine
981*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
982*da0073e9SAndroid Build Coastguard Worker            test_method(1), "x[torch.tensor([1])]=tensor([1], device='cuda:0')"
983*da0073e9SAndroid Build Coastguard Worker        )
984*da0073e9SAndroid Build Coastguard Worker        # Test that indexing out of bounds causes assert
985*da0073e9SAndroid Build Coastguard Worker        self._spawn_method(test_method, 11)
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker    @slowTest
988*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
989*da0073e9SAndroid Build Coastguard Worker    @serialTest()
990*da0073e9SAndroid Build Coastguard Worker    def test_huge_index(self):
991*da0073e9SAndroid Build Coastguard Worker        src = torch.empty(15000000, 45, device="cuda", dtype=torch.long).random_(
992*da0073e9SAndroid Build Coastguard Worker            0, 2**22
993*da0073e9SAndroid Build Coastguard Worker        )
994*da0073e9SAndroid Build Coastguard Worker        idx = torch.randperm(src.shape[0], device="cuda")
995*da0073e9SAndroid Build Coastguard Worker        res = src[idx]
996*da0073e9SAndroid Build Coastguard Worker        res_cpu = src.cpu()[idx.cpu()]
997*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.cpu(), res_cpu)
998*da0073e9SAndroid Build Coastguard Worker
999*da0073e9SAndroid Build Coastguard Worker    def test_randint_randomness_for_large_range(self) -> None:
1000*da0073e9SAndroid Build Coastguard Worker        # For large ranges, randint generation is slightly different. This led to a subtle bug where some Philox
1001*da0073e9SAndroid Build Coastguard Worker        # offsets were not calculated correctly, resulting in reused random states.
1002*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/125224
1003*da0073e9SAndroid Build Coastguard Worker        size = 1_000_000
1004*da0073e9SAndroid Build Coastguard Worker        high = 6_000_000_000  # Keep this above 2**32
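        # Ranges wider than 2**32 take a different sampling path that consumes Philox state
        # differently; per the issue above, that is where the offset bookkeeping went wrong.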
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard Worker        def run(dev: torch.device) -> int:
1007*da0073e9SAndroid Build Coastguard Worker            # Measure how many unique numbers are generated in 2 consecutive calls to randint. If random states are
1008*da0073e9SAndroid Build Coastguard Worker            # reused, this will yield fewer unique numbers.
1009*da0073e9SAndroid Build Coastguard Worker            gen = torch.Generator(device=dev)
1010*da0073e9SAndroid Build Coastguard Worker            gen.manual_seed(0)
1011*da0073e9SAndroid Build Coastguard Worker            t1 = torch.randint(
1012*da0073e9SAndroid Build Coastguard Worker                0, high, [size], device=dev, generator=gen, dtype=torch.int64
1013*da0073e9SAndroid Build Coastguard Worker            )
1014*da0073e9SAndroid Build Coastguard Worker            t2 = torch.randint(
1015*da0073e9SAndroid Build Coastguard Worker                0, high, [size], device=dev, generator=gen, dtype=torch.int64
1016*da0073e9SAndroid Build Coastguard Worker            )
1017*da0073e9SAndroid Build Coastguard Worker            return torch.stack([t1, t2]).unique().shape[0]
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker        # Use CPU as reference. The results should not deviate too much.
1020*da0073e9SAndroid Build Coastguard Worker        assert abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000
1021*da0073e9SAndroid Build Coastguard Worker
1022*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float32, torch.double])
1023*da0073e9SAndroid Build Coastguard Worker    def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None:
1024*da0073e9SAndroid Build Coastguard Worker        # Test if random states do not overlap between consecutive rand/randn calls.
1025*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/125224
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker        def run(func, dev: torch.device, dtype: torch.dtype) -> int:
1028*da0073e9SAndroid Build Coastguard Worker            # Measure how many unique numbers are generated in 2 consecutive calls. If random states are
1029*da0073e9SAndroid Build Coastguard Worker            # reused, this will yield fewer unique numbers.
1030*da0073e9SAndroid Build Coastguard Worker            size = 1000000
1031*da0073e9SAndroid Build Coastguard Worker            gen = torch.Generator(device=dev)
1032*da0073e9SAndroid Build Coastguard Worker            gen.manual_seed(0)
1033*da0073e9SAndroid Build Coastguard Worker            t1 = func((size,), device=dev, generator=gen, dtype=dtype)
1034*da0073e9SAndroid Build Coastguard Worker            t2 = func((size,), device=dev, generator=gen, dtype=dtype)
1035*da0073e9SAndroid Build Coastguard Worker            return torch.stack([t1, t2]).unique().shape[0]
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker        # Use CPU as reference. The results should not deviate too much.
1038*da0073e9SAndroid Build Coastguard Worker        for func in [torch.rand, torch.randn]:
1039*da0073e9SAndroid Build Coastguard Worker            deviation = abs(
1040*da0073e9SAndroid Build Coastguard Worker                run(func, torch.device("cuda"), dtype)
1041*da0073e9SAndroid Build Coastguard Worker                - run(func, torch.device("cpu"), dtype)
1042*da0073e9SAndroid Build Coastguard Worker            )
1043*da0073e9SAndroid Build Coastguard Worker            assert deviation < 50_000, deviation
1044*da0073e9SAndroid Build Coastguard Worker
1045*da0073e9SAndroid Build Coastguard Worker    def test_min_max_inits(self):
1046*da0073e9SAndroid Build Coastguard Worker        # Testing if THC_reduceAll received the correct index initialization.
1047*da0073e9SAndroid Build Coastguard Worker        # This affects the result of THC_reduceAll operations at extreme values
1048*da0073e9SAndroid Build Coastguard Worker        x = torch.cuda.ByteTensor([0])
1049*da0073e9SAndroid Build Coastguard Worker        y = torch.cuda.ByteTensor([255])
1050*da0073e9SAndroid Build Coastguard Worker        expected = torch.cuda.LongTensor([0])[0]
1051*da0073e9SAndroid Build Coastguard Worker
1052*da0073e9SAndroid Build Coastguard Worker        _, v = x.max(dim=0)
1053*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v, expected)
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker        _, v = y.min(dim=0)
1056*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v, expected)
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Worker    def test_nvtx(self):
1059*da0073e9SAndroid Build Coastguard Worker        # Just making sure we can see the symbols
1060*da0073e9SAndroid Build Coastguard Worker        torch.cuda.nvtx.range_push("foo")
1061*da0073e9SAndroid Build Coastguard Worker        torch.cuda.nvtx.mark("bar")
1062*da0073e9SAndroid Build Coastguard Worker        torch.cuda.nvtx.range_pop()
1063*da0073e9SAndroid Build Coastguard Worker        range_handle = torch.cuda.nvtx.range_start("range_start")
1064*da0073e9SAndroid Build Coastguard Worker        torch.cuda.nvtx.range_end(range_handle)
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker    def test_bincount_ext(self):
1067*da0073e9SAndroid Build Coastguard Worker        # ensure CUDA code coverage
1068*da0073e9SAndroid Build Coastguard Worker        input_size = (100000,)
1069*da0073e9SAndroid Build Coastguard Worker        w = torch.randn(input_size, dtype=torch.double, device="cuda")
1070*da0073e9SAndroid Build Coastguard Worker        w_cpu = w.cpu()
1071*da0073e9SAndroid Build Coastguard Worker        # test shared memory impl
1072*da0073e9SAndroid Build Coastguard Worker        t = torch.randint(50, input_size, dtype=torch.int8, device="cuda")
1073*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(), t.bincount())
1074*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
1075*da0073e9SAndroid Build Coastguard Worker        # test global memory impl
1076*da0073e9SAndroid Build Coastguard Worker        #   see `CUDAHistogramMemoryType` in SummaryOps.cu
1077*da0073e9SAndroid Build Coastguard Worker        #   50000 * sizeof(int64_t) == 390 KiB, which should exceed smem of any known GPU
1078*da0073e9SAndroid Build Coastguard Worker        t = torch.randint(50000, input_size, dtype=torch.int64, device="cuda")
1079*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(), t.bincount())
1080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
1081*da0073e9SAndroid Build Coastguard Worker
1082*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros([10], dtype=torch.int32, device="cuda")
1083*da0073e9SAndroid Build Coastguard Worker        # 35488 * 65536 as int32 would cause overflow to negative value
1084*da0073e9SAndroid Build Coastguard Worker        # giving negative bin offset
1085*da0073e9SAndroid Build Coastguard Worker        t[0] = 35488
1086*da0073e9SAndroid Build Coastguard Worker        counted = t.bincount(minlength=65536)
1087*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sum(counted), 10)
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker    def test_tiny_half_norm_(self):
1090*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(25).cuda().float()
1091*da0073e9SAndroid Build Coastguard Worker        a /= 100000000
1092*da0073e9SAndroid Build Coastguard Worker        b = a.half()
1093*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(b.norm().item(), 0)
1094*da0073e9SAndroid Build Coastguard Worker
1095*da0073e9SAndroid Build Coastguard Worker    def test_norm_type_conversion(self):
1096*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(65536).cuda().half()
1097*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker    def test_cuda_memory_leak_detection_propagates_errors(self):
1100*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1101*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"The size of tensor a \(3\) must match"
1102*da0073e9SAndroid Build Coastguard Worker        ):
1103*da0073e9SAndroid Build Coastguard Worker            with self.assertLeaksNoCudaTensors():
1104*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(3, 1, device="cuda")
1105*da0073e9SAndroid Build Coastguard Worker                y = torch.randn(2, 1, device="cuda")
1106*da0073e9SAndroid Build Coastguard Worker                z = x + y
1107*da0073e9SAndroid Build Coastguard Worker
1108*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MEDIUM_TENSOR, "not enough memory")
1109*da0073e9SAndroid Build Coastguard Worker    @serialTest()
1110*da0073e9SAndroid Build Coastguard Worker    def test_cuda_kernel_loop_overflow(self):
1111*da0073e9SAndroid Build Coastguard Worker        # Issue #24309: In extreme cases, the loop variable could overflow and continue
1112*da0073e9SAndroid Build Coastguard Worker        # the kernel loop with a negative index, causing a RuntimeError (invalid write):
1113*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1, 1, 2**30 + 1, dtype=torch.float16, device="cuda")
1114*da0073e9SAndroid Build Coastguard Worker        expected = x[0, 0, 0, 2**30]
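        # expected captures the last element along the huge dimension before the pooling op runs;
        # the equality check afterwards confirms the kernel indexed the full range correctly.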
1115*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1116*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
1117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[0, 0, 0, 2**30], expected)
1118*da0073e9SAndroid Build Coastguard Worker
1119*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
1120*da0073e9SAndroid Build Coastguard Worker    @gcIfJetson
1121*da0073e9SAndroid Build Coastguard Worker    @serialTest()
1122*da0073e9SAndroid Build Coastguard Worker    def test_cuda_kernel_loop_overflow_large(self):
1123*da0073e9SAndroid Build Coastguard Worker        # Make sure input.numel() > INT_MAX is handled:
1124*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1, 1, 2**31, dtype=torch.float16, device="cuda")
1125*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "integer out of range"):
1126*da0073e9SAndroid Build Coastguard Worker            y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1127*da0073e9SAndroid Build Coastguard Worker
1128*da0073e9SAndroid Build Coastguard Worker        # Issue #24309: In extreme cases, the loop variable could overflow and continue
1129*da0073e9SAndroid Build Coastguard Worker        # the kernel loop with a negative index, causing a RuntimeError (invalid write):
1130*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1, 1, 2**31 - 1, dtype=torch.float16, device="cuda")
1131*da0073e9SAndroid Build Coastguard Worker        expected = x[0, 0, 0, 2**31 - 2]
1132*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1133*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
1134*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y[0, 0, 0, 2**31 - 2], expected)
1135*da0073e9SAndroid Build Coastguard Worker
1136*da0073e9SAndroid Build Coastguard Worker    # this might create a reference cycle on self...
1137*da0073e9SAndroid Build Coastguard Worker    def _make_multiply_in_stream(self):
1138*da0073e9SAndroid Build Coastguard Worker        class MultiplyInStream(torch.autograd.Function):
1139*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1140*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, val):
1141*da0073e9SAndroid Build Coastguard Worker                ctx.val = val
1142*da0073e9SAndroid Build Coastguard Worker                ctx.stream = torch.cuda.current_stream()
1143*da0073e9SAndroid Build Coastguard Worker                return x * val
1144*da0073e9SAndroid Build Coastguard Worker
1145*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1146*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
1147*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.cuda.current_stream(), ctx.stream)
1148*da0073e9SAndroid Build Coastguard Worker                # delays the operation in the background stream
1149*da0073e9SAndroid Build Coastguard Worker                torch.cuda._sleep(1000 * 5000)
1150*da0073e9SAndroid Build Coastguard Worker                return grad * ctx.val, None
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker        return MultiplyInStream
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker    @skipCUDANonDefaultStreamIf(True)
1155*da0073e9SAndroid Build Coastguard Worker    def test_streaming_backwards_sync(self):
1156*da0073e9SAndroid Build Coastguard Worker        default_stream = torch.cuda.current_stream()
1157*da0073e9SAndroid Build Coastguard Worker        stream = torch.cuda.Stream()
1158*da0073e9SAndroid Build Coastguard Worker
1159*da0073e9SAndroid Build Coastguard Worker        MultiplyInStream = self._make_multiply_in_stream()
1160*da0073e9SAndroid Build Coastguard Worker
1161*da0073e9SAndroid Build Coastguard Worker        # Tests using grads outside the backward() stream context
1162*da0073e9SAndroid Build Coastguard Worker        # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html
1163*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, device="cuda", requires_grad=True)
1164*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(stream):
1165*da0073e9SAndroid Build Coastguard Worker            stream.wait_stream(default_stream)
1166*da0073e9SAndroid Build Coastguard Worker            output = MultiplyInStream.apply(x, 2)
1167*da0073e9SAndroid Build Coastguard Worker            output.sum().backward()
1168*da0073e9SAndroid Build Coastguard Worker        # Sync needed: grads were produced on "stream" but are consumed below on the default stream.
1169*da0073e9SAndroid Build Coastguard Worker        default_stream.wait_stream(stream)
1170*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones_like(x) * 2)
1171*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.cuda.current_stream(), default_stream)
1172*da0073e9SAndroid Build Coastguard Worker
1173*da0073e9SAndroid Build Coastguard Worker        # Tests that using grads in the same stream context as backward()
1174*da0073e9SAndroid Build Coastguard Worker        # is safe regardless what streams bwd ops ran on
1175*da0073e9SAndroid Build Coastguard Worker        bwd_ambient_stream = torch.cuda.Stream()
1176*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, device="cuda", requires_grad=True)
1177*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(stream):
1178*da0073e9SAndroid Build Coastguard Worker            stream.wait_stream(default_stream)
1179*da0073e9SAndroid Build Coastguard Worker            output = MultiplyInStream.apply(x, 3)
1180*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(bwd_ambient_stream):
1181*da0073e9SAndroid Build Coastguard Worker            bwd_ambient_stream.wait_stream(stream)
1182*da0073e9SAndroid Build Coastguard Worker            output.sum().backward()
1183*da0073e9SAndroid Build Coastguard Worker            # x was first used on "stream" so its AccumulateGrad leaf should run on "stream".
1184*da0073e9SAndroid Build Coastguard Worker            # The end of backward() should have synced "bwd_ambient_stream" with "stream"
1185*da0073e9SAndroid Build Coastguard Worker            # so it should be safe to use x.grad here without any syncs.
1186*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, torch.ones_like(x) * 3)
1187*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.current_stream(), bwd_ambient_stream)
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Worker    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
1190*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm(msg="flaky on ROCm https://github.com/pytorch/pytorch/issues/53190")
1191*da0073e9SAndroid Build Coastguard Worker    def test_streaming_backwards_multiple_streams(self):
1192*da0073e9SAndroid Build Coastguard Worker        MultiplyInStream = self._make_multiply_in_stream()
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker        class StreamModel(torch.nn.Module):
1195*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1196*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1197*da0073e9SAndroid Build Coastguard Worker                self.event = torch.cuda.Event()
1198*da0073e9SAndroid Build Coastguard Worker                self.stream0 = torch.cuda.Stream()
1199*da0073e9SAndroid Build Coastguard Worker                self.stream1 = torch.cuda.Stream()
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, x_first_use_on_ambient):
1202*da0073e9SAndroid Build Coastguard Worker                if x_first_use_on_ambient:
1203*da0073e9SAndroid Build Coastguard Worker                    x0 = x.clone()
1204*da0073e9SAndroid Build Coastguard Worker                self.stream0.wait_stream(torch.cuda.current_stream())
1205*da0073e9SAndroid Build Coastguard Worker                self.stream1.wait_stream(torch.cuda.current_stream())
1206*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.stream(self.stream0):
1207*da0073e9SAndroid Build Coastguard Worker                    if not x_first_use_on_ambient:
1208*da0073e9SAndroid Build Coastguard Worker                        x0 = x.clone()
1209*da0073e9SAndroid Build Coastguard Worker                    y0 = MultiplyInStream.apply(x0, 2)
1210*da0073e9SAndroid Build Coastguard Worker                    self.event.record(stream=torch.cuda.current_stream())
1211*da0073e9SAndroid Build Coastguard Worker
1212*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.stream(self.stream1):
1213*da0073e9SAndroid Build Coastguard Worker                    y1 = MultiplyInStream.apply(x, 3)
1214*da0073e9SAndroid Build Coastguard Worker                    self.stream1.wait_event(self.event)
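                    # y0 was produced on stream0; waiting on the recorded event makes the
                    # sum below safe to compute on stream1.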
1215*da0073e9SAndroid Build Coastguard Worker                    return y0 + y1
1216*da0073e9SAndroid Build Coastguard Worker
1217*da0073e9SAndroid Build Coastguard Worker        stream = torch.cuda.Stream()
1218*da0073e9SAndroid Build Coastguard Worker
1219*da0073e9SAndroid Build Coastguard Worker        for x_first_use_on_ambient in (True, False):
1220*da0073e9SAndroid Build Coastguard Worker            # the out_of_place=False, iters=1 case stresses whether proper syncs are inserted
1221*da0073e9SAndroid Build Coastguard Worker            # when grads are initially None and stolen by backward ops.
1222*da0073e9SAndroid Build Coastguard Worker            for out_of_place, iters in ((True, 1), (False, 1), (False, 5)):
1223*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.stream(stream):
1224*da0073e9SAndroid Build Coastguard Worker                    x = torch.randn(5, 5, device="cuda", requires_grad=True)
1225*da0073e9SAndroid Build Coastguard Worker                    model = StreamModel().cuda()
1226*da0073e9SAndroid Build Coastguard Worker                    x.register_hook(
1227*da0073e9SAndroid Build Coastguard Worker                        lambda grad: self.assertEqual(
1228*da0073e9SAndroid Build Coastguard Worker                            torch.cuda.current_stream(),
1229*da0073e9SAndroid Build Coastguard Worker                            stream if x_first_use_on_ambient else model.stream0,
1230*da0073e9SAndroid Build Coastguard Worker                        )
1231*da0073e9SAndroid Build Coastguard Worker                    )
1232*da0073e9SAndroid Build Coastguard Worker                    for p in model.parameters():
1233*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(p.grad is None)
1234*da0073e9SAndroid Build Coastguard Worker                    for i in range(iters):
1235*da0073e9SAndroid Build Coastguard Worker                        loss = model(x, x_first_use_on_ambient).sum()
1236*da0073e9SAndroid Build Coastguard Worker                        if out_of_place:
1237*da0073e9SAndroid Build Coastguard Worker                            x_grad = torch.autograd.grad((loss,), (x,))[0]
1238*da0073e9SAndroid Build Coastguard Worker                        else:
1239*da0073e9SAndroid Build Coastguard Worker                            loss.backward()
1240*da0073e9SAndroid Build Coastguard Worker                # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html
1241*da0073e9SAndroid Build Coastguard Worker                torch.cuda.current_stream().wait_stream(stream)
1242*da0073e9SAndroid Build Coastguard Worker
1243*da0073e9SAndroid Build Coastguard Worker                if out_of_place:
1244*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(x_grad, torch.ones_like(x) * 5 * iters)
1245*da0073e9SAndroid Build Coastguard Worker                else:
1246*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(x.grad, torch.ones_like(x) * 5 * iters)
1247*da0073e9SAndroid Build Coastguard Worker
1248*da0073e9SAndroid Build Coastguard Worker    def test_streaming_backwards_sync_graph_root(self):
1249*da0073e9SAndroid Build Coastguard Worker        # This function tests if bwd ops running on a side stream properly sync with the GraphRoot.
1250*da0073e9SAndroid Build Coastguard Worker        # The potential bug it targets is a race condition. The test uses multiple trials and
1251*da0073e9SAndroid Build Coastguard Worker        # torch.cuda._sleep such that if the race condition exists, the test will almost certainly fail,
1252*da0073e9SAndroid Build Coastguard Worker        # but there's a chance it may spuriously pass. Passing does not guarantee the backend is bug-free,
1253*da0073e9SAndroid Build Coastguard Worker        # but failure does guarantee there is a bug.
1254*da0073e9SAndroid Build Coastguard Worker        fwd_bwd_op_stream = torch.cuda.Stream()
1255*da0073e9SAndroid Build Coastguard Worker        bwd_ambient_stream = torch.cuda.Stream()
1256*da0073e9SAndroid Build Coastguard Worker        # We need these streams to be different otherwise the test is meaningless.
1257*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(fwd_bwd_op_stream != bwd_ambient_stream)
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker        size = int(1e3)
1260*da0073e9SAndroid Build Coastguard Worker
1261*da0073e9SAndroid Build Coastguard Worker        a = torch.full((size,), 2.0, device="cuda", requires_grad=True)
1262*da0073e9SAndroid Build Coastguard Worker        b = torch.full((size,), 3.0, device="cuda", requires_grad=True)
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker        # I don't think we need any manual record_streams below.
1265*da0073e9SAndroid Build Coastguard Worker        # a and b remain in scope for the entire test.
1266*da0073e9SAndroid Build Coastguard Worker        # c and grad remain in scope for each iteration, and there's a full sync between iterations.
1267*da0073e9SAndroid Build Coastguard Worker        for trial in range(5):
1268*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
1269*da0073e9SAndroid Build Coastguard Worker            a.grad = b.grad = None
1270*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(fwd_bwd_op_stream):
1271*da0073e9SAndroid Build Coastguard Worker                c = a * b
1272*da0073e9SAndroid Build Coastguard Worker
1273*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(bwd_ambient_stream):
1274*da0073e9SAndroid Build Coastguard Worker                torch.cuda.synchronize()
1275*da0073e9SAndroid Build Coastguard Worker                # Long-running dummy kernel on bwd_ambient_stream delays filling of grad
1276*da0073e9SAndroid Build Coastguard Worker                torch.cuda._sleep(int(50 * get_cycles_per_ms()))
1277*da0073e9SAndroid Build Coastguard Worker                # Fills grad on bwd_ambient_stream
1278*da0073e9SAndroid Build Coastguard Worker                grad = torch.full((size,), float(trial + 1), device="cuda")
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker                # Bwd ops still run on fwd_bwd_ops_stream, so the following will likely fail if
1281*da0073e9SAndroid Build Coastguard Worker                # bwd ops don't sync with bwd_ambient_stream before consuming grad.
1282*da0073e9SAndroid Build Coastguard Worker                torch.autograd.backward(tensors=c, grad_tensors=grad)
1283*da0073e9SAndroid Build Coastguard Worker
1284*da0073e9SAndroid Build Coastguard Worker                # See https://github.com/pytorch/pytorch/issues/47028
1285*da0073e9SAndroid Build Coastguard Worker                # The assertEqual calls below run on bwd_ambient_stream, so this test may also fail
1286*da0073e9SAndroid Build Coastguard Worker                # if backward() fails to sync with bwd_ambient_stream at the end.
1287*da0073e9SAndroid Build Coastguard Worker                # Synchronizing here works around the issue until a proper fix can be made.
1288*da0073e9SAndroid Build Coastguard Worker                torch.cuda.synchronize()
1289*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
1290*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.grad, grad * b)
1291*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.grad, grad * a)
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker    def test_streaming_backwards_callback(self):
1294*da0073e9SAndroid Build Coastguard Worker        # Tests whether autograd callbacks sync properly with respect to leaf streams and
1295*da0073e9SAndroid Build Coastguard Worker        # the user-facing stream surrounding backward(). If it fails, the first suspect is
1296*da0073e9SAndroid Build Coastguard Worker        # the sync logic where "final_callbacks_" are called in torch/csrc/autograd/engine.cpp
1297*da0073e9SAndroid Build Coastguard Worker        MultiplyInStream = self._make_multiply_in_stream()
1298*da0073e9SAndroid Build Coastguard Worker
1299*da0073e9SAndroid Build Coastguard Worker        size = int(1e3)
1300*da0073e9SAndroid Build Coastguard Worker        a = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)
1301*da0073e9SAndroid Build Coastguard Worker        b = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker        s0 = torch.cuda.Stream()
1304*da0073e9SAndroid Build Coastguard Worker        s1 = torch.cuda.Stream()
1305*da0073e9SAndroid Build Coastguard Worker        s2 = torch.cuda.Stream()
1306*da0073e9SAndroid Build Coastguard Worker
1307*da0073e9SAndroid Build Coastguard Worker        stash = []
1308*da0073e9SAndroid Build Coastguard Worker
1309*da0073e9SAndroid Build Coastguard Worker        # sets up a nontrivial structure of leaf streams
1310*da0073e9SAndroid Build Coastguard Worker        s0.wait_stream(torch.cuda.current_stream())
1311*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s0):
1312*da0073e9SAndroid Build Coastguard Worker            c = MultiplyInStream.apply(a, 2)
1313*da0073e9SAndroid Build Coastguard Worker
1314*da0073e9SAndroid Build Coastguard Worker        s1.wait_stream(torch.cuda.current_stream())
1315*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s1):
1316*da0073e9SAndroid Build Coastguard Worker            d = MultiplyInStream.apply(b, 3)
1317*da0073e9SAndroid Build Coastguard Worker            s1.wait_stream(s0)
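            # c was produced on s0; the wait above makes c * d safe to run on s1.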
1318*da0073e9SAndroid Build Coastguard Worker            e = c * d
1319*da0073e9SAndroid Build Coastguard Worker
1320*da0073e9SAndroid Build Coastguard Worker            def clone_leaf_grads():
1321*da0073e9SAndroid Build Coastguard Worker                stash.append(a.grad.clone())
1322*da0073e9SAndroid Build Coastguard Worker                stash.append(b.grad.clone())
1323*da0073e9SAndroid Build Coastguard Worker
1324*da0073e9SAndroid Build Coastguard Worker            # Use a hook on e to install the callback
1325*da0073e9SAndroid Build Coastguard Worker            e.register_hook(
1326*da0073e9SAndroid Build Coastguard Worker                lambda grad: torch.autograd.Variable._execution_engine.queue_callback(
1327*da0073e9SAndroid Build Coastguard Worker                    clone_leaf_grads
1328*da0073e9SAndroid Build Coastguard Worker                )
1329*da0073e9SAndroid Build Coastguard Worker            )
1330*da0073e9SAndroid Build Coastguard Worker
1331*da0073e9SAndroid Build Coastguard Worker        s2.wait_stream(s1)
1332*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s2):
1333*da0073e9SAndroid Build Coastguard Worker            e.sum().backward()
1334*da0073e9SAndroid Build Coastguard Worker            # The autograd engine should sync s2 with all leaf streams then run the callback clone_leaf_grads on s2.
1335*da0073e9SAndroid Build Coastguard Worker            # If those things happened properly, checking the values of the cloned grads on s2 should be safe:
1336*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(stash[0], torch.full_like(a, 6))
1337*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(stash[1], torch.full_like(a, 6))
1338*da0073e9SAndroid Build Coastguard Worker
1339*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1340*da0073e9SAndroid Build Coastguard Worker        TEST_WITH_ROCM,
1341*da0073e9SAndroid Build Coastguard Worker        "In ROCm, kernel asserts are disabled due to performance overhead",
1342*da0073e9SAndroid Build Coastguard Worker    )
1343*da0073e9SAndroid Build Coastguard Worker    def test_fixed_cuda_assert_async(self):
1344*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1345*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Boolean value of Tensor with no values is ambiguous"
1346*da0073e9SAndroid Build Coastguard Worker        ):
1347*da0073e9SAndroid Build Coastguard Worker            torch._assert_async(torch.tensor([], device="cuda"))
1348*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1349*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1350*da0073e9SAndroid Build Coastguard Worker            "Boolean value of Tensor with more than one value is ambiguous",
1351*da0073e9SAndroid Build Coastguard Worker        ):
1352*da0073e9SAndroid Build Coastguard Worker            torch._assert_async(torch.tensor([0, 0], device="cuda"))
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(1, device="cuda"))
1355*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(0.1, device="cuda"))
1356*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(-0.1, device="cuda"))
1357*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(True, device="cuda"))
1358*da0073e9SAndroid Build Coastguard Worker        torch._assert_async(torch.tensor(0 + 0.1j, device="cuda"))
1359*da0073e9SAndroid Build Coastguard Worker
1360*da0073e9SAndroid Build Coastguard Worker        fail_stmts = [
1361*da0073e9SAndroid Build Coastguard Worker            "torch._assert_async(torch.tensor(0, device='cuda'))",
1362*da0073e9SAndroid Build Coastguard Worker            "torch._assert_async(torch.tensor(0.0, device='cuda'))",
1363*da0073e9SAndroid Build Coastguard Worker            "torch._assert_async(torch.tensor(False, device='cuda'))",
1364*da0073e9SAndroid Build Coastguard Worker            "torch._assert_async(torch.tensor(0 + 0j, device='cuda'))",
1365*da0073e9SAndroid Build Coastguard Worker        ]
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker        import subprocess
1368*da0073e9SAndroid Build Coastguard Worker
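        # A failing device-side assert poisons the CUDA context, so each failing
        # statement is exercised in a fresh subprocess and only its nonzero exit
        # status is checked here.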
1369*da0073e9SAndroid Build Coastguard Worker        for stmt in fail_stmts:
1370*da0073e9SAndroid Build Coastguard Worker            with self.subTest(stmt=stmt):
1371*da0073e9SAndroid Build Coastguard Worker                r = subprocess.call(
1372*da0073e9SAndroid Build Coastguard Worker                    [
1373*da0073e9SAndroid Build Coastguard Worker                        sys.executable,
1374*da0073e9SAndroid Build Coastguard Worker                        "-c",
1375*da0073e9SAndroid Build Coastguard Worker                        f"""\
1376*da0073e9SAndroid Build Coastguard Workerimport torch
1377*da0073e9SAndroid Build Coastguard Worker
1378*da0073e9SAndroid Build Coastguard Worker{stmt}
1379*da0073e9SAndroid Build Coastguard Workertorch.cuda.synchronize()
1380*da0073e9SAndroid Build Coastguard Worker""",
1381*da0073e9SAndroid Build Coastguard Worker                    ]
1382*da0073e9SAndroid Build Coastguard Worker                )
1383*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(r != 0)
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "fails with the cudaMallocAsync backend")
1386*da0073e9SAndroid Build Coastguard Worker    def test_cublas_multiple_threads_same_device(self):
1387*da0073e9SAndroid Build Coastguard Worker        # Note: these parameters should be tuned very carefully.
1388*da0073e9SAndroid Build Coastguard Worker        # Values that are too small make the race condition hard to trigger,
1389*da0073e9SAndroid Build Coastguard Worker        # while values that are too large sometimes cause a hang.
1390*da0073e9SAndroid Build Coastguard Worker        size = 1024
1391*da0073e9SAndroid Build Coastguard Worker        num_threads = 2
1392*da0073e9SAndroid Build Coastguard Worker        trials = 3
1393*da0073e9SAndroid Build Coastguard Worker        test_iters = 100
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker        weight = torch.ones((size, size), device="cuda")
1396*da0073e9SAndroid Build Coastguard Worker        results = {}
1397*da0073e9SAndroid Build Coastguard Worker        barrier = threading.Barrier(num_threads)
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker        def _worker(t):
1400*da0073e9SAndroid Build Coastguard Worker            my_stream = torch.cuda.Stream()
1401*da0073e9SAndroid Build Coastguard Worker            # Hard sync so we don't need to worry about creating and using tensors
1402*da0073e9SAndroid Build Coastguard Worker            # across streams or the fact that default streams are thread-local.
1403*da0073e9SAndroid Build Coastguard Worker            # Those issues are not the target of this test.
1404*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
1405*da0073e9SAndroid Build Coastguard Worker            # Line up threads to increase likelihood of race conditions.
1406*da0073e9SAndroid Build Coastguard Worker            barrier.wait()
1407*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(my_stream):
1408*da0073e9SAndroid Build Coastguard Worker                for _ in range(test_iters):
1409*da0073e9SAndroid Build Coastguard Worker                    # If all threads are sharing the same cublas handle,
1410*da0073e9SAndroid Build Coastguard Worker                    # the following sequence may occur:
1411*da0073e9SAndroid Build Coastguard Worker                    # thread 0 calls cublasSetStream()
1412*da0073e9SAndroid Build Coastguard Worker                    # thread 1 calls cublasSetStream()
1413*da0073e9SAndroid Build Coastguard Worker                    # thread 0 launches its raw gemm, which it thinks is in
1414*da0073e9SAndroid Build Coastguard Worker                    #          its own stream, but is actually in thread 1's stream.
1415*da0073e9SAndroid Build Coastguard Worker                    # thread 0 enqueues its div_, which IS in its own stream,
1416*da0073e9SAndroid Build Coastguard Worker                    #          but actually now races with its gemm.
1417*da0073e9SAndroid Build Coastguard Worker                    results[t] = torch.mm(results[t], weight)
1418*da0073e9SAndroid Build Coastguard Worker                    results[t].div_(float(size))
1419*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
1420*da0073e9SAndroid Build Coastguard Worker
1421*da0073e9SAndroid Build Coastguard Worker        for _ in range(trials):
1422*da0073e9SAndroid Build Coastguard Worker            for t in range(num_threads):
1423*da0073e9SAndroid Build Coastguard Worker                results[t] = torch.ones((size, size), device="cuda")
1424*da0073e9SAndroid Build Coastguard Worker
1425*da0073e9SAndroid Build Coastguard Worker            threads = [
1426*da0073e9SAndroid Build Coastguard Worker                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
1427*da0073e9SAndroid Build Coastguard Worker            ]
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker            for thread in threads:
1430*da0073e9SAndroid Build Coastguard Worker                thread.start()
1431*da0073e9SAndroid Build Coastguard Worker            for thread in threads:
1432*da0073e9SAndroid Build Coastguard Worker                thread.join()
1433*da0073e9SAndroid Build Coastguard Worker
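            # Multiplying by an all-ones (size, size) matrix makes every element equal
            # to its row sum (size), and the subsequent div_(size) restores it to 1,
            # so a correct run keeps the total at size * size.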
1434*da0073e9SAndroid Build Coastguard Worker            for t in range(num_threads):
1435*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(results[t].sum().item(), size * size)
1436*da0073e9SAndroid Build Coastguard Worker
1437*da0073e9SAndroid Build Coastguard Worker    # Test is flaky on Windows (https://github.com/pytorch/pytorch/issues/57401)
1438*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows (see issue 57401)")
1439*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
1440*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
1441*da0073e9SAndroid Build Coastguard Worker    def test_cudnn_multiple_threads_same_device(self):
1442*da0073e9SAndroid Build Coastguard Worker        # This function is intended to test the lazy creation and reuse of per-thread
1443*da0073e9SAndroid Build Coastguard Worker        # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp.
1444*da0073e9SAndroid Build Coastguard Worker        # Failure here likely indicates something wrong with that logic.
1445*da0073e9SAndroid Build Coastguard Worker        weight = torch.ones((1, 1, 2, 2), device="cuda")
1446*da0073e9SAndroid Build Coastguard Worker
1447*da0073e9SAndroid Build Coastguard Worker        results = {}
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker        num_threads = 2
1450*da0073e9SAndroid Build Coastguard Worker        trials = 3
1451*da0073e9SAndroid Build Coastguard Worker        test_iters = 1000
1452*da0073e9SAndroid Build Coastguard Worker        barrier = threading.Barrier(num_threads)
1453*da0073e9SAndroid Build Coastguard Worker
1454*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=True):
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Worker            def _worker(t):
1457*da0073e9SAndroid Build Coastguard Worker                my_stream = torch.cuda.Stream()
1458*da0073e9SAndroid Build Coastguard Worker                # Hard sync so we don't need to worry about creating and using tensors
1459*da0073e9SAndroid Build Coastguard Worker                # across streams or the fact that default streams are thread-local.
1460*da0073e9SAndroid Build Coastguard Worker                # Those issues are not the target of this test.
1461*da0073e9SAndroid Build Coastguard Worker                torch.cuda.synchronize()
1462*da0073e9SAndroid Build Coastguard Worker                # Line up threads to increase likelihood of race conditions.
1463*da0073e9SAndroid Build Coastguard Worker                barrier.wait()
1464*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.stream(my_stream):
1465*da0073e9SAndroid Build Coastguard Worker                    for _ in range(test_iters):
1466*da0073e9SAndroid Build Coastguard Worker                        # If all threads are sharing the same cudnn handle,
1467*da0073e9SAndroid Build Coastguard Worker                        # the following sequence may occur:
1468*da0073e9SAndroid Build Coastguard Worker                        # thread 0 calls setCuDNNStreamToCurrent()
1469*da0073e9SAndroid Build Coastguard Worker                        # thread 1 calls setCuDNNStreamToCurrent()
1470*da0073e9SAndroid Build Coastguard Worker                        # thread 0 launches its raw convolution, which it thinks is in
1471*da0073e9SAndroid Build Coastguard Worker                        #          its own stream, but is actually in thread 1's stream.
1472*da0073e9SAndroid Build Coastguard Worker                        # thread 0 enqueues its div_, which IS in its own stream,
1473*da0073e9SAndroid Build Coastguard Worker                        #          but now races with its convolution.
1474*da0073e9SAndroid Build Coastguard Worker                        results[t] = torch.nn.functional.conv2d(
1475*da0073e9SAndroid Build Coastguard Worker                            results[t], weight, padding=0
1476*da0073e9SAndroid Build Coastguard Worker                        )
1477*da0073e9SAndroid Build Coastguard Worker                        results[t].div_(4.0)
1478*da0073e9SAndroid Build Coastguard Worker                torch.cuda.synchronize()
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker            for _ in range(trials):
1481*da0073e9SAndroid Build Coastguard Worker                for t in range(num_threads):
1482*da0073e9SAndroid Build Coastguard Worker                    results[t] = torch.ones((1, 1, 2048, 2048), device="cuda")
1483*da0073e9SAndroid Build Coastguard Worker
1484*da0073e9SAndroid Build Coastguard Worker                threads = [
1485*da0073e9SAndroid Build Coastguard Worker                    threading.Thread(target=_worker, args=(t,))
1486*da0073e9SAndroid Build Coastguard Worker                    for t in range(num_threads)
1487*da0073e9SAndroid Build Coastguard Worker                ]
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker                for thread in threads:
1490*da0073e9SAndroid Build Coastguard Worker                    thread.start()
1491*da0073e9SAndroid Build Coastguard Worker                for thread in threads:
1492*da0073e9SAndroid Build Coastguard Worker                    thread.join()
1493*da0073e9SAndroid Build Coastguard Worker
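                # Each 2x2 all-ones convolution with no padding shrinks both spatial
                # dims by 1 and sums four ones; div_(4.0) restores the values to 1, so
                # after test_iters steps the output is (2048 - test_iters)^2 ones.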
1494*da0073e9SAndroid Build Coastguard Worker                for t in range(num_threads):
1495*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
1496*da0073e9SAndroid Build Coastguard Worker                        results[t].sum().item(),
1497*da0073e9SAndroid Build Coastguard Worker                        (2048 - test_iters) * (2048 - test_iters),
1498*da0073e9SAndroid Build Coastguard Worker                    )
1499*da0073e9SAndroid Build Coastguard Worker
1500*da0073e9SAndroid Build Coastguard Worker    def test_cusparse_multiple_threads_same_device(self):
1501*da0073e9SAndroid Build Coastguard Worker        size = 1024
1502*da0073e9SAndroid Build Coastguard Worker        num_threads = 2
1503*da0073e9SAndroid Build Coastguard Worker        trials = 3
1504*da0073e9SAndroid Build Coastguard Worker        test_iters = 500
1505*da0073e9SAndroid Build Coastguard Worker
1506*da0073e9SAndroid Build Coastguard Worker        def ones_sparse(size):
1507*da0073e9SAndroid Build Coastguard Worker            a = torch.arange(size, device="cuda")
1508*da0073e9SAndroid Build Coastguard Worker            indices = torch.cartesian_prod(a, a).t()
1509*da0073e9SAndroid Build Coastguard Worker            values = torch.ones(size * size, device="cuda")
1510*da0073e9SAndroid Build Coastguard Worker            return torch.sparse_coo_tensor(indices, values)
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker        weight = ones_sparse(size)
1513*da0073e9SAndroid Build Coastguard Worker        results = {}
1514*da0073e9SAndroid Build Coastguard Worker        barrier = threading.Barrier(num_threads)
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker        def _worker(t):
1517*da0073e9SAndroid Build Coastguard Worker            my_stream = torch.cuda.Stream()
1518*da0073e9SAndroid Build Coastguard Worker            # Hard sync so we don't need to worry about creating and using tensors
1519*da0073e9SAndroid Build Coastguard Worker            # across streams or the fact that default streams are thread-local.
1520*da0073e9SAndroid Build Coastguard Worker            # Those issues are not the target of this test.
1521*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
1522*da0073e9SAndroid Build Coastguard Worker            # Line up threads to increase likelihood of race conditions.
1523*da0073e9SAndroid Build Coastguard Worker            barrier.wait()
1524*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(my_stream):
1525*da0073e9SAndroid Build Coastguard Worker                for _ in range(test_iters):
1526*da0073e9SAndroid Build Coastguard Worker                    # If all threads are sharing the same cusparse handle,
1527*da0073e9SAndroid Build Coastguard Worker                    # the following sequence may occur:
1528*da0073e9SAndroid Build Coastguard Worker                    # thread 0 calls cusparseSetStream()
1529*da0073e9SAndroid Build Coastguard Worker                    # thread 1 calls cusparseSetStream()
1530*da0073e9SAndroid Build Coastguard Worker                    # thread 0 launches its raw sparse matmul, which it thinks is in
1531*da0073e9SAndroid Build Coastguard Worker                    #          its own stream, but is actually in thread 1's stream.
1532*da0073e9SAndroid Build Coastguard Worker                    # thread 0 enqueues its div_, which IS in its own stream,
1533*da0073e9SAndroid Build Coastguard Worker                    #          but actually now races with its sparse matmul.
1534*da0073e9SAndroid Build Coastguard Worker                    results[t] = weight.mm(results[t])
1535*da0073e9SAndroid Build Coastguard Worker                    results[t].div_(float(size))
1536*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
1537*da0073e9SAndroid Build Coastguard Worker
1538*da0073e9SAndroid Build Coastguard Worker        for _ in range(trials):
1539*da0073e9SAndroid Build Coastguard Worker            for t in range(num_threads):
1540*da0073e9SAndroid Build Coastguard Worker                results[t] = torch.ones((size, size), device="cuda")
1541*da0073e9SAndroid Build Coastguard Worker
1542*da0073e9SAndroid Build Coastguard Worker            threads = [
1543*da0073e9SAndroid Build Coastguard Worker                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
1544*da0073e9SAndroid Build Coastguard Worker            ]
1545*da0073e9SAndroid Build Coastguard Worker
1546*da0073e9SAndroid Build Coastguard Worker            for thread in threads:
1547*da0073e9SAndroid Build Coastguard Worker                thread.start()
1548*da0073e9SAndroid Build Coastguard Worker            for thread in threads:
1549*da0073e9SAndroid Build Coastguard Worker                thread.join()
1550*da0073e9SAndroid Build Coastguard Worker
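            # Same invariant as the dense cuBLAS test above: every mm followed by
            # div_(size) keeps the values at 1, so the total should remain size * size.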
1551*da0073e9SAndroid Build Coastguard Worker            for t in range(num_threads):
1552*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(results[t].sum().item(), size * size)
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Worker    @slowTest
1555*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
1556*da0073e9SAndroid Build Coastguard Worker    @serialTest()
1557*da0073e9SAndroid Build Coastguard Worker    def test_max_large_axis(self):
1558*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(2**32, device="cuda", dtype=torch.int8)
1559*da0073e9SAndroid Build Coastguard Worker        x[-1] = 1
1560*da0073e9SAndroid Build Coastguard Worker        val, idx = x.max(0)
1561*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(val, 1)
1562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(idx, x.shape[0] - 1)
1563*da0073e9SAndroid Build Coastguard Worker
1564*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1565*da0073e9SAndroid Build Coastguard Worker    def test_to_numpy(self):
1566*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy())
1567*da0073e9SAndroid Build Coastguard Worker
1568*da0073e9SAndroid Build Coastguard Worker    def test_graph_is_current_stream_capturing(self):
1569*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.cuda.is_current_stream_capturing())
1570*da0073e9SAndroid Build Coastguard Worker
1571*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA and (not TEST_WITH_ROCM):
1572*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream()
1573*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
1574*da0073e9SAndroid Build Coastguard Worker                g = torch.cuda.CUDAGraph()
1575*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.cuda.is_current_stream_capturing())
1576*da0073e9SAndroid Build Coastguard Worker                g.capture_begin()
1577*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.cuda.is_current_stream_capturing())
1578*da0073e9SAndroid Build Coastguard Worker                g.capture_end()
1579*da0073e9SAndroid Build Coastguard Worker
1580*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1581*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1582*da0073e9SAndroid Build Coastguard Worker    )
1583*da0073e9SAndroid Build Coastguard Worker    def test_graph_capture_simple(self):
1584*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
1585*da0073e9SAndroid Build Coastguard Worker
1586*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
1587*da0073e9SAndroid Build Coastguard Worker            a = torch.full((1000,), 1, device="cuda")
1588*da0073e9SAndroid Build Coastguard Worker            g = torch.cuda.CUDAGraph()
1589*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
1590*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
1591*da0073e9SAndroid Build Coastguard Worker            b = a
1592*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
1593*da0073e9SAndroid Build Coastguard Worker                b = b + 1
1594*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
1595*da0073e9SAndroid Build Coastguard Worker        torch.cuda.current_stream().wait_stream(s)
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker        g.replay()
1598*da0073e9SAndroid Build Coastguard Worker
1599*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b.sum().item() == 11000.0)
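
    # Illustrative sketch, not an upstream test: the manual capture above can also be
    # written with the torch.cuda.graph context manager, which handles the side-stream
    # setup internally. Assumes a CUDA device; the helper name is hypothetical.
    @staticmethod
    def _graph_context_manager_sketch():
        a = torch.full((1000,), 1, device="cuda")
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            b = a + 10
        g.replay()
        return b  # holds 11s once the replayed kernels finish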
1600*da0073e9SAndroid Build Coastguard Worker
1601*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1602*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1603*da0073e9SAndroid Build Coastguard Worker    )
1604*da0073e9SAndroid Build Coastguard Worker    def test_graphsafe_set_get_rng_state(self):
1605*da0073e9SAndroid Build Coastguard Worker        # Define a function to create generator states, with optional graph registration
1606*da0073e9SAndroid Build Coastguard Worker        # Define a function to create generator states; registration with a CUDA graph happens separately below
1607*da0073e9SAndroid Build Coastguard Worker            """Initializes generator states and registers them with a CUDA graph if provided."""
1608*da0073e9SAndroid Build Coastguard Worker            """Initializes the generator and returns it with its current state and a cloned state."""
1609*da0073e9SAndroid Build Coastguard Worker            torch.rand(1, device="cuda")
1610*da0073e9SAndroid Build Coastguard Worker            generator.manual_seed(0)
1611*da0073e9SAndroid Build Coastguard Worker
1612*da0073e9SAndroid Build Coastguard Worker            # Save the current state of the generator
1613*da0073e9SAndroid Build Coastguard Worker            old_state = generator.graphsafe_get_state()
1614*da0073e9SAndroid Build Coastguard Worker            # Create and save a cloned state of the generator
1615*da0073e9SAndroid Build Coastguard Worker            new_state = generator.clone_state()
1616*da0073e9SAndroid Build Coastguard Worker            # Return the original generator and its two states
1617*da0073e9SAndroid Build Coastguard Worker            return generator, old_state, new_state
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker        def register_states_to_graph(generator_state, graph):
1620*da0073e9SAndroid Build Coastguard Worker            generator, old_state, new_state = generator_state
1621*da0073e9SAndroid Build Coastguard Worker            graph.register_generator_state(old_state)
1622*da0073e9SAndroid Build Coastguard Worker            graph.register_generator_state(new_state)
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker        # Define a function to perform specific RNG actions using the generator's states
1625*da0073e9SAndroid Build Coastguard Worker        def perform_random_generation_steps(generator_state):
1626*da0073e9SAndroid Build Coastguard Worker            generator, old_state, new_state = generator_state
1627*da0073e9SAndroid Build Coastguard Worker            random_values = []
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker            # Generate random numbers with the new generator state
1630*da0073e9SAndroid Build Coastguard Worker            generator.graphsafe_set_state(new_state)
1631*da0073e9SAndroid Build Coastguard Worker            random_values.append(torch.rand(5, device="cuda", generator=generator))
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker            # Generate random numbers twice with the old generator state
1634*da0073e9SAndroid Build Coastguard Worker            generator.graphsafe_set_state(old_state)
1635*da0073e9SAndroid Build Coastguard Worker            random_values.extend(
1636*da0073e9SAndroid Build Coastguard Worker                [torch.rand(5, device="cuda", generator=generator) for _ in range(2)]
1637*da0073e9SAndroid Build Coastguard Worker            )
1638*da0073e9SAndroid Build Coastguard Worker
1639*da0073e9SAndroid Build Coastguard Worker            return random_values
1640*da0073e9SAndroid Build Coastguard Worker
1641*da0073e9SAndroid Build Coastguard Worker        # Define a function to retrieve the final offsets of the original and new generator states
1642*da0073e9SAndroid Build Coastguard Worker        def get_final_offsets_of_states(generator_state):
1643*da0073e9SAndroid Build Coastguard Worker            generator, old_state, new_state = generator_state
1644*da0073e9SAndroid Build Coastguard Worker            old_state_offset = old_state.get_offset()
1645*da0073e9SAndroid Build Coastguard Worker            new_state_offset = new_state.get_offset()
1646*da0073e9SAndroid Build Coastguard Worker            return old_state_offset, new_state_offset
1647*da0073e9SAndroid Build Coastguard Worker
1648*da0073e9SAndroid Build Coastguard Worker        # Set up and test a new CUDA generator
1649*da0073e9SAndroid Build Coastguard Worker        generator = torch.Generator(device="cuda")
1650*da0073e9SAndroid Build Coastguard Worker        generator_state = create_states(generator)
1651*da0073e9SAndroid Build Coastguard Worker
1652*da0073e9SAndroid Build Coastguard Worker        # Set up and test the default CUDA generator with a CUDA Graph
1653*da0073e9SAndroid Build Coastguard Worker        g = torch.cuda.CUDAGraph()
1654*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
1655*da0073e9SAndroid Build Coastguard Worker        default_generator = torch.cuda.default_generators[0]
1656*da0073e9SAndroid Build Coastguard Worker        default_generator_state = create_states(default_generator)
1657*da0073e9SAndroid Build Coastguard Worker        register_states_to_graph(default_generator_state, g)
1658*da0073e9SAndroid Build Coastguard Worker
1659*da0073e9SAndroid Build Coastguard Worker        # Perform random number generation within a CUDA graph
1660*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
1661*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
1662*da0073e9SAndroid Build Coastguard Worker            graphed_random_values = perform_random_generation_steps(
1663*da0073e9SAndroid Build Coastguard Worker                default_generator_state
1664*da0073e9SAndroid Build Coastguard Worker            )
1665*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
1666*da0073e9SAndroid Build Coastguard Worker
1667*da0073e9SAndroid Build Coastguard Worker        # Synchronize the streams and replay the graph
1668*da0073e9SAndroid Build Coastguard Worker        torch.cuda.current_stream().wait_stream(s)
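        # Both generators perform the same draws per iteration (one with the cloned
        # state, two with the original), so their offsets should advance in lockstep
        # and the generated values should match.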
1669*da0073e9SAndroid Build Coastguard Worker        for _ in range(3):
1670*da0073e9SAndroid Build Coastguard Worker            random_values = perform_random_generation_steps(generator_state)
1671*da0073e9SAndroid Build Coastguard Worker            g.replay()
1672*da0073e9SAndroid Build Coastguard Worker            offset = get_final_offsets_of_states(generator_state)
1673*da0073e9SAndroid Build Coastguard Worker            graph_offset = get_final_offsets_of_states(default_generator_state)
1674*da0073e9SAndroid Build Coastguard Worker
1675*da0073e9SAndroid Build Coastguard Worker            # Compare the final offsets of states for both generators to ensure consistency
1676*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(offset == graph_offset)
1677*da0073e9SAndroid Build Coastguard Worker            # Compare the states generated outside and inside the graph
1678*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(random_values, graphed_random_values)
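
    # Condensed sketch, not an upstream test: the core graph-safe RNG recipe the test
    # above exercises. A cloned generator state is registered with the graph before
    # capture, selected inside the capture, and then advanced by replay() just like
    # eager execution would advance it. Assumes a CUDA device; the helper name is
    # hypothetical.
    @staticmethod
    def _graphsafe_rng_sketch():
        gen = torch.cuda.default_generators[0]
        torch.rand(1, device="cuda")  # make sure the CUDA generator is initialized
        cloned = gen.clone_state()
        g = torch.cuda.CUDAGraph()
        g.register_generator_state(cloned)
        s = torch.cuda.Stream()
        with torch.cuda.stream(s):
            g.capture_begin()
            gen.graphsafe_set_state(cloned)
            captured = torch.rand(5, device="cuda")
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)
        g.replay()
        return captured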
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1681*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1682*da0073e9SAndroid Build Coastguard Worker    )
1683*da0073e9SAndroid Build Coastguard Worker    def test_memory_stats_of_multiple_generators_and_graphs(self):
1684*da0073e9SAndroid Build Coastguard Worker        # Function to clear CUDA cache and collect garbage
1685*da0073e9SAndroid Build Coastguard Worker        def clear_cuda_cache():
1686*da0073e9SAndroid Build Coastguard Worker            gc.collect()
1687*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
1688*da0073e9SAndroid Build Coastguard Worker
1689*da0073e9SAndroid Build Coastguard Worker        # Captures a random number generation within a CUDA graph and then replays it.
1690*da0073e9SAndroid Build Coastguard Worker        def simple_graph_task(graph):
1691*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream()
1692*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
1693*da0073e9SAndroid Build Coastguard Worker                graph.capture_begin()
1694*da0073e9SAndroid Build Coastguard Worker                torch.rand(1, device="cuda")
1695*da0073e9SAndroid Build Coastguard Worker                graph.capture_end()
1696*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(s)
1697*da0073e9SAndroid Build Coastguard Worker            graph.replay()  # Replays the captured operations
1698*da0073e9SAndroid Build Coastguard Worker
1699*da0073e9SAndroid Build Coastguard Worker        def get_memory_stats():
1700*da0073e9SAndroid Build Coastguard Worker            stats = torch.cuda.memory_stats()
1701*da0073e9SAndroid Build Coastguard Worker            num_blocks = stats["active.all.current"]
1702*da0073e9SAndroid Build Coastguard Worker            total_size = stats["active_bytes.all.current"]
1703*da0073e9SAndroid Build Coastguard Worker            return num_blocks, total_size
1704*da0073e9SAndroid Build Coastguard Worker
1705*da0073e9SAndroid Build Coastguard Worker        def test(num_graphs, num_generators):
1706*da0073e9SAndroid Build Coastguard Worker            baseline = get_memory_stats()
1707*da0073e9SAndroid Build Coastguard Worker            baseline_num_blocks, baseline_total_size = baseline
1708*da0073e9SAndroid Build Coastguard Worker
1709*da0073e9SAndroid Build Coastguard Worker            # Allocate CUDA graphs
1710*da0073e9SAndroid Build Coastguard Worker            graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)]
1711*da0073e9SAndroid Build Coastguard Worker
1712*da0073e9SAndroid Build Coastguard Worker            # Allocate and manage generator states
1713*da0073e9SAndroid Build Coastguard Worker            default_generator = torch.cuda.default_generators[0]
1714*da0073e9SAndroid Build Coastguard Worker            generators = [default_generator.graphsafe_get_state()]
1715*da0073e9SAndroid Build Coastguard Worker
1716*da0073e9SAndroid Build Coastguard Worker            # Starts from 1 as one state is already added
1717*da0073e9SAndroid Build Coastguard Worker            for _ in range(1, num_generators):
1718*da0073e9SAndroid Build Coastguard Worker                generators.append(default_generator.clone_state())
1719*da0073e9SAndroid Build Coastguard Worker
1720*da0073e9SAndroid Build Coastguard Worker            for graph in graphs:
1721*da0073e9SAndroid Build Coastguard Worker                for generator_state in generators:
1722*da0073e9SAndroid Build Coastguard Worker                    graph.register_generator_state(generator_state)
1723*da0073e9SAndroid Build Coastguard Worker                simple_graph_task(graph)
1724*da0073e9SAndroid Build Coastguard Worker
1725*da0073e9SAndroid Build Coastguard Worker            # Assert conditions after graph tasks
1726*da0073e9SAndroid Build Coastguard Worker            num_blocks, total_size = get_memory_stats()
1727*da0073e9SAndroid Build Coastguard Worker            # The allocated blocks should only be proportional to the number of generators
1728*da0073e9SAndroid Build Coastguard Worker            expected_blocks_diff = 2 * num_generators
1729*da0073e9SAndroid Build Coastguard Worker            expected_size_diff = 2 * 512 * num_generators  # Each block's size is 512
1730*da0073e9SAndroid Build Coastguard Worker
1731*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1732*da0073e9SAndroid Build Coastguard Worker                (num_blocks - baseline_num_blocks) == expected_blocks_diff,
1733*da0073e9SAndroid Build Coastguard Worker                "Unexpected number of active blocks.",
1734*da0073e9SAndroid Build Coastguard Worker            )
1735*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1736*da0073e9SAndroid Build Coastguard Worker                (total_size - baseline_total_size) == expected_size_diff,
1737*da0073e9SAndroid Build Coastguard Worker                "Unexpected total memory size.",
1738*da0073e9SAndroid Build Coastguard Worker            )
1739*da0073e9SAndroid Build Coastguard Worker
1740*da0073e9SAndroid Build Coastguard Worker            # Cleanup graphs and clear CUDA cache
1741*da0073e9SAndroid Build Coastguard Worker            while graphs:
1742*da0073e9SAndroid Build Coastguard Worker                graph = graphs.pop()
1743*da0073e9SAndroid Build Coastguard Worker                del graph
1744*da0073e9SAndroid Build Coastguard Worker            clear_cuda_cache()
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker            # Assert that memory stats return to baseline after cleanup
1747*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1748*da0073e9SAndroid Build Coastguard Worker                get_memory_stats() == baseline,
1749*da0073e9SAndroid Build Coastguard Worker                "Memory stats do not match baseline after cleanup.",
1750*da0073e9SAndroid Build Coastguard Worker            )
1751*da0073e9SAndroid Build Coastguard Worker
1752*da0073e9SAndroid Build Coastguard Worker        # Running the test function with different parameters
1753*da0073e9SAndroid Build Coastguard Worker        test(1, 1)
1754*da0073e9SAndroid Build Coastguard Worker        test(3, 2)
1755*da0073e9SAndroid Build Coastguard Worker        test(10, 20)
1756*da0073e9SAndroid Build Coastguard Worker
1757*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1758*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1759*da0073e9SAndroid Build Coastguard Worker    )
1760*da0073e9SAndroid Build Coastguard Worker    def test_graph_capture_reset_recapture(self):
1761*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
1762*da0073e9SAndroid Build Coastguard Worker
1763*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
1764*da0073e9SAndroid Build Coastguard Worker            a = torch.full((1000,), 1, device="cuda")
1765*da0073e9SAndroid Build Coastguard Worker            g = torch.cuda.CUDAGraph()
1766*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
1767*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
1768*da0073e9SAndroid Build Coastguard Worker            b = a
1769*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
1770*da0073e9SAndroid Build Coastguard Worker                b = b + 1
1771*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
1772*da0073e9SAndroid Build Coastguard Worker        torch.cuda.current_stream().wait_stream(s)
1773*da0073e9SAndroid Build Coastguard Worker
1774*da0073e9SAndroid Build Coastguard Worker        g.replay()
1775*da0073e9SAndroid Build Coastguard Worker
1776*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b.sum().item() == 11000.0)
1777*da0073e9SAndroid Build Coastguard Worker
1778*da0073e9SAndroid Build Coastguard Worker        g.reset()
1779*da0073e9SAndroid Build Coastguard Worker
1780*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
1781*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
1782*da0073e9SAndroid Build Coastguard Worker            b.fill_(2.0)
1783*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
1784*da0073e9SAndroid Build Coastguard Worker                b = b + 2
1785*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
1786*da0073e9SAndroid Build Coastguard Worker        torch.cuda.current_stream().wait_stream(s)
1787*da0073e9SAndroid Build Coastguard Worker
1788*da0073e9SAndroid Build Coastguard Worker        g.replay()
1789*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b.sum().item() == 22000.0)
1790*da0073e9SAndroid Build Coastguard Worker
1791*da0073e9SAndroid Build Coastguard Worker        g.reset()
1792*da0073e9SAndroid Build Coastguard Worker        del g
1793*da0073e9SAndroid Build Coastguard Worker
1794*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1795*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1796*da0073e9SAndroid Build Coastguard Worker    )
1797*da0073e9SAndroid Build Coastguard Worker    def test_graph_debugdump(self):
1798*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
1799*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10240000, device="cuda")
1800*da0073e9SAndroid Build Coastguard Worker        y = torch.rand_like(x)
1801*da0073e9SAndroid Build Coastguard Worker        g = torch.cuda.CUDAGraph()
1802*da0073e9SAndroid Build Coastguard Worker        g.enable_debug_mode()
1803*da0073e9SAndroid Build Coastguard Worker        s0 = torch.cuda.Stream()
1804*da0073e9SAndroid Build Coastguard Worker        s1 = torch.cuda.Stream()
1805*da0073e9SAndroid Build Coastguard Worker        s0.wait_stream(torch.cuda.current_stream())
1806*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s0):
1807*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
1808*da0073e9SAndroid Build Coastguard Worker            z = x + y
1809*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s1):
1810*da0073e9SAndroid Build Coastguard Worker                s1.wait_stream(s0)
1811*da0073e9SAndroid Build Coastguard Worker                w = z + y
1812*da0073e9SAndroid Build Coastguard Worker            s0.wait_stream(s1)
1813*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
1814*da0073e9SAndroid Build Coastguard Worker        s0.synchronize()
1815*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
1816*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryDirectory() as tempdir:
1817*da0073e9SAndroid Build Coastguard Worker            g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot"))
1818*da0073e9SAndroid Build Coastguard Worker
1819*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1820*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1821*da0073e9SAndroid Build Coastguard Worker    )
1822*da0073e9SAndroid Build Coastguard Worker    def test_graph_error(self):
1823*da0073e9SAndroid Build Coastguard Worker        # We need to run this test in a separate process, as the error we trigger
1824*da0073e9SAndroid Build Coastguard Worker        # puts the CUDA context in a bad state
1825*da0073e9SAndroid Build Coastguard Worker        script = """
1826*da0073e9SAndroid Build Coastguard Workerimport torch
1827*da0073e9SAndroid Build Coastguard Worker
1828*da0073e9SAndroid Build Coastguard Workerg = torch.cuda.CUDAGraph()
1829*da0073e9SAndroid Build Coastguard Workertry:
1830*da0073e9SAndroid Build Coastguard Worker    g.capture_begin()
1831*da0073e9SAndroid Build Coastguard Workerexcept RuntimeError as e:
1832*da0073e9SAndroid Build Coastguard Worker    if "CUDA graphs must be captured on a non-default stream." in str(e):
1833*da0073e9SAndroid Build Coastguard Worker        exit(0)
1834*da0073e9SAndroid Build Coastguard Worker    else:
1835*da0073e9SAndroid Build Coastguard Worker        exit(1)
1836*da0073e9SAndroid Build Coastguard Workerexit(2)
1837*da0073e9SAndroid Build Coastguard Worker"""
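        # Exit-code protocol of the child process: 0 means the expected
        # "non-default stream" error was raised, 1 means a different RuntimeError
        # was raised, and 2 means capture_begin unexpectedly succeeded.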
1838*da0073e9SAndroid Build Coastguard Worker        try:
1839*da0073e9SAndroid Build Coastguard Worker            a = subprocess.check_output(
1840*da0073e9SAndroid Build Coastguard Worker                [sys.executable, "-c", script],
1841*da0073e9SAndroid Build Coastguard Worker                stderr=subprocess.STDOUT,
1842*da0073e9SAndroid Build Coastguard Worker                # On Windows, opening the subprocess with the default CWD makes `import torch`
1843*da0073e9SAndroid Build Coastguard Worker                # fail, so just set CWD to this script's directory
1844*da0073e9SAndroid Build Coastguard Worker                cwd=os.path.dirname(os.path.realpath(__file__)),
1845*da0073e9SAndroid Build Coastguard Worker            )
1846*da0073e9SAndroid Build Coastguard Worker        except subprocess.CalledProcessError as e:
1847*da0073e9SAndroid Build Coastguard Worker            if e.returncode == 1:
1848*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1849*da0073e9SAndroid Build Coastguard Worker                    False,
1850*da0073e9SAndroid Build Coastguard Worker                    "Error raised by starting capture without a stream is not the expected one",
1851*da0073e9SAndroid Build Coastguard Worker                )
1852*da0073e9SAndroid Build Coastguard Worker            elif e.returncode == 2:
1853*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1854*da0073e9SAndroid Build Coastguard Worker                    False,
1855*da0073e9SAndroid Build Coastguard Worker                    "Error raised by starting capture without a stream was not caught",
1856*da0073e9SAndroid Build Coastguard Worker                )
1857*da0073e9SAndroid Build Coastguard Worker
1858*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1859*da0073e9SAndroid Build Coastguard Worker        (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11,
1860*da0073e9SAndroid Build Coastguard Worker        "CUDA >= 11.0 required for graphs",
1861*da0073e9SAndroid Build Coastguard Worker    )
1862*da0073e9SAndroid Build Coastguard Worker    def test_graph_warn_if_has_zero_nodes(self):
1863*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as caught:
1864*da0073e9SAndroid Build Coastguard Worker            g = torch.cuda.CUDAGraph()
1865*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream()
1866*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
1867*da0073e9SAndroid Build Coastguard Worker                g.capture_begin()
1868*da0073e9SAndroid Build Coastguard Worker                g.capture_end()
1869*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
1870*da0073e9SAndroid Build Coastguard Worker            any("The CUDA Graph is empty" in str(w.message) for w in caught)
1871*da0073e9SAndroid Build Coastguard Worker        )
1872*da0073e9SAndroid Build Coastguard Worker
1873*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1874*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1875*da0073e9SAndroid Build Coastguard Worker    )
1876*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1877*da0073e9SAndroid Build Coastguard Worker        IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
1878*da0073e9SAndroid Build Coastguard Worker    )
1879*da0073e9SAndroid Build Coastguard Worker    def test_graph_capture_oom(self):
1880*da0073e9SAndroid Build Coastguard Worker        oom_regex = (
1881*da0073e9SAndroid Build Coastguard Worker            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
1882*da0073e9SAndroid Build Coastguard Worker        )
1883*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, oom_regex):
1884*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.graph(torch.cuda.CUDAGraph()):
1885*da0073e9SAndroid Build Coastguard Worker                torch.zeros(2**40, device="cuda")
1886*da0073e9SAndroid Build Coastguard Worker
1887*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1888*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1889*da0073e9SAndroid Build Coastguard Worker    )
1890*da0073e9SAndroid Build Coastguard Worker    @serialTest()
1891*da0073e9SAndroid Build Coastguard Worker    def test_repeat_graph_capture_cublas_workspace_memory(self):
1892*da0073e9SAndroid Build Coastguard Worker        (x, y, z) = 1024, 512, 64
1893*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((x, y), device="cuda")
1894*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((y, z), device="cuda")
1895*da0073e9SAndroid Build Coastguard Worker
1896*da0073e9SAndroid Build Coastguard Worker        # warmup
1897*da0073e9SAndroid Build Coastguard Worker        torch.mm(a, b)
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker        free_bytes_before, total_bytes = torch.cuda.mem_get_info()
1900*da0073e9SAndroid Build Coastguard Worker        used_gb_before = (total_bytes - free_bytes_before) / 1e9
1901*da0073e9SAndroid Build Coastguard Worker
1902*da0073e9SAndroid Build Coastguard Worker        for _ in range(100):
1903*da0073e9SAndroid Build Coastguard Worker            torch_graph = torch.cuda.CUDAGraph()
1904*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.graph(torch_graph):
1905*da0073e9SAndroid Build Coastguard Worker                torch.mm(a, b)
1906*da0073e9SAndroid Build Coastguard Worker            torch_graph.replay()
1907*da0073e9SAndroid Build Coastguard Worker
1908*da0073e9SAndroid Build Coastguard Worker        free_bytes_after, _ = torch.cuda.mem_get_info()
1909*da0073e9SAndroid Build Coastguard Worker        used_gb_after = (total_bytes - free_bytes_after) / 1e9
1910*da0073e9SAndroid Build Coastguard Worker
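        # Allow roughly 0.1 GB of slack; growth beyond that would suggest cuBLAS
        # workspace memory is accumulating across repeated captures, which is what
        # this test guards against.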
1911*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(used_gb_before + 0.1 < used_gb_after)
1912*da0073e9SAndroid Build Coastguard Worker
1913*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1914*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1915*da0073e9SAndroid Build Coastguard Worker    )
1916*da0073e9SAndroid Build Coastguard Worker    def test_graph_rng_functional(self):
1917*da0073e9SAndroid Build Coastguard Worker        ops_with_kwargs = (
1918*da0073e9SAndroid Build Coastguard Worker            (torch.nn.functional.dropout, {"p": 0.1}),
1919*da0073e9SAndroid Build Coastguard Worker            (torch.nn.functional.rrelu, {"training": True}),
1920*da0073e9SAndroid Build Coastguard Worker        )
1921*da0073e9SAndroid Build Coastguard Worker        size = 10000
1922*da0073e9SAndroid Build Coastguard Worker
1923*da0073e9SAndroid Build Coastguard Worker        def run(op, kwargs):
1924*da0073e9SAndroid Build Coastguard Worker            a = torch.randn((size,), device="cuda", dtype=torch.float)
1925*da0073e9SAndroid Build Coastguard Worker
1926*da0073e9SAndroid Build Coastguard Worker            # Control
1927*da0073e9SAndroid Build Coastguard Worker            torch.cuda.manual_seed(5)
1928*da0073e9SAndroid Build Coastguard Worker            eager_out = a
1929*da0073e9SAndroid Build Coastguard Worker            for _ in range(6):
1930*da0073e9SAndroid Build Coastguard Worker                eager_out = op(eager_out, **kwargs)
1931*da0073e9SAndroid Build Coastguard Worker
1932*da0073e9SAndroid Build Coastguard Worker            graph_in = a.clone()
1933*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.Stream()
1934*da0073e9SAndroid Build Coastguard Worker            stream.wait_stream(torch.cuda.current_stream())
1935*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(stream):
1936*da0073e9SAndroid Build Coastguard Worker                torch.cuda.manual_seed(5)
1937*da0073e9SAndroid Build Coastguard Worker
1938*da0073e9SAndroid Build Coastguard Worker                g = torch.cuda.CUDAGraph()
1939*da0073e9SAndroid Build Coastguard Worker                torch.cuda.empty_cache()
1940*da0073e9SAndroid Build Coastguard Worker                g.capture_begin()
1941*da0073e9SAndroid Build Coastguard Worker                graph_out = graph_in
1942*da0073e9SAndroid Build Coastguard Worker                for _ in range(2):
1943*da0073e9SAndroid Build Coastguard Worker                    graph_out = op(graph_out, **kwargs)
1944*da0073e9SAndroid Build Coastguard Worker                g.capture_end()
1945*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(stream)
1946*da0073e9SAndroid Build Coastguard Worker
1947*da0073e9SAndroid Build Coastguard Worker            # Runs a graphed->eager->graphed sequence of RNG ops.
1948*da0073e9SAndroid Build Coastguard Worker            # replay() plays 2 invocations of the op, so the sequence has 6
1949*da0073e9SAndroid Build Coastguard Worker            # invocations total, matching Control.
1950*da0073e9SAndroid Build Coastguard Worker            # replay() reads from graph_in and writes to graph_out.
1951*da0073e9SAndroid Build Coastguard Worker            g.replay()
1952*da0073e9SAndroid Build Coastguard Worker            out = op(graph_out, **kwargs)
1953*da0073e9SAndroid Build Coastguard Worker            out = op(out, **kwargs)
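            # replay() always reads from the buffer captured as graph_in, so the fresh
            # eager results must be copied back into that same tensor before replaying.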
1954*da0073e9SAndroid Build Coastguard Worker            graph_in.copy_(out)
1955*da0073e9SAndroid Build Coastguard Worker            g.replay()
1956*da0073e9SAndroid Build Coastguard Worker
1957*da0073e9SAndroid Build Coastguard Worker            # If replay() updated RNG state correctly, graph_out
1958*da0073e9SAndroid Build Coastguard Worker            # should now hold data equal to eager_out.
1959*da0073e9SAndroid Build Coastguard Worker            try:
1960*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(eager_out, graph_out)
1961*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
1962*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(f"Failed on {op}") from e
1963*da0073e9SAndroid Build Coastguard Worker
1964*da0073e9SAndroid Build Coastguard Worker            # Do the same operations varying seeds
1965*da0073e9SAndroid Build Coastguard Worker            seeds = [6, 128, 9999]
1966*da0073e9SAndroid Build Coastguard Worker
1967*da0073e9SAndroid Build Coastguard Worker            for seed in seeds:
1968*da0073e9SAndroid Build Coastguard Worker                torch.cuda.manual_seed(seed)
1969*da0073e9SAndroid Build Coastguard Worker                graph_in.copy_(a)
1970*da0073e9SAndroid Build Coastguard Worker                for _ in range(3):
1971*da0073e9SAndroid Build Coastguard Worker                    g.replay()
1972*da0073e9SAndroid Build Coastguard Worker
1973*da0073e9SAndroid Build Coastguard Worker                # If the random seed was not updated then the graph would
1974*da0073e9SAndroid Build Coastguard Worker                # generate the same output as in the previous check.
1975*da0073e9SAndroid Build Coastguard Worker                try:
1976*da0073e9SAndroid Build Coastguard Worker                    self.assertNotEqual(eager_out, graph_out)
1977*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
1978*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(f"Failed on {op}") from e
1979*da0073e9SAndroid Build Coastguard Worker
1980*da0073e9SAndroid Build Coastguard Worker                # Now repeat the same operations in non-graphed mode.
1981*da0073e9SAndroid Build Coastguard Worker                torch.cuda.manual_seed(seed)
1982*da0073e9SAndroid Build Coastguard Worker                for _ in range(3):
1983*da0073e9SAndroid Build Coastguard Worker                    eager_out.copy_(a)
1984*da0073e9SAndroid Build Coastguard Worker                    eager_out = op(eager_out, **kwargs)
1985*da0073e9SAndroid Build Coastguard Worker                    eager_out = op(eager_out, **kwargs)
1986*da0073e9SAndroid Build Coastguard Worker
1987*da0073e9SAndroid Build Coastguard Worker                # In the end, graph_out and eager_out must be equal
1988*da0073e9SAndroid Build Coastguard Worker                # as they went through the same set of operations.
1989*da0073e9SAndroid Build Coastguard Worker                try:
1990*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(eager_out, graph_out)
1991*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
1992*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(f"Failed on {op}") from e
1993*da0073e9SAndroid Build Coastguard Worker
1994*da0073e9SAndroid Build Coastguard Worker            # We hold references to all tensors used across streams up until this sync,
1995*da0073e9SAndroid Build Coastguard Worker            # so no need to call record_stream on those tensors.
1996*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
1997*da0073e9SAndroid Build Coastguard Worker
1998*da0073e9SAndroid Build Coastguard Worker        for op, kwargs in ops_with_kwargs:
1999*da0073e9SAndroid Build Coastguard Worker            run(op, kwargs)
2000*da0073e9SAndroid Build Coastguard Worker
2001*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2002*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2003*da0073e9SAndroid Build Coastguard Worker    )
2004*da0073e9SAndroid Build Coastguard Worker    def test_graph_rng_distributions(self):
2005*da0073e9SAndroid Build Coastguard Worker        size = 10000
2006*da0073e9SAndroid Build Coastguard Worker        input = torch.rand((size,), device="cuda", dtype=torch.float)
2007*da0073e9SAndroid Build Coastguard Worker        alloc = torch.empty((size,), device="cuda", dtype=torch.float)
2008*da0073e9SAndroid Build Coastguard Worker
2009*da0073e9SAndroid Build Coastguard Worker        # Torch ops to test with sample args (tuple) and kwargs (dict)
2010*da0073e9SAndroid Build Coastguard Worker        torch_with_args = (
2011*da0073e9SAndroid Build Coastguard Worker            ("bernoulli", (input.clone(),), {}),
2012*da0073e9SAndroid Build Coastguard Worker            # multinomial uses some uncapturable CUDA calls.
2013*da0073e9SAndroid Build Coastguard Worker            # TODO: reenable multinomial tests if/when the implementation is capturable.
2014*da0073e9SAndroid Build Coastguard Worker            # ("multinomial", (input.clone(), size, True), {}),
2015*da0073e9SAndroid Build Coastguard Worker            # ("multinomial", (input.clone(), size // 2, False), {}),
2016*da0073e9SAndroid Build Coastguard Worker            # TODO: reenable normal test, where std is a device
2017*da0073e9SAndroid Build Coastguard Worker            # tensor, when graph test failures are fixed
2018*da0073e9SAndroid Build Coastguard Worker            # ("normal", (input.clone() + 1, input.clone()), {}),
2019*da0073e9SAndroid Build Coastguard Worker            ("normal", (input.clone() + 1, 1.0), {}),
2020*da0073e9SAndroid Build Coastguard Worker            ("poisson", (input.clone(),), {}),
2021*da0073e9SAndroid Build Coastguard Worker            ("rand", (size,), {"device": "cuda", "dtype": torch.float}),
2022*da0073e9SAndroid Build Coastguard Worker            ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}),
2023*da0073e9SAndroid Build Coastguard Worker            ("randn", (size,), {"device": "cuda", "dtype": torch.float}),
2024*da0073e9SAndroid Build Coastguard Worker        )
2025*da0073e9SAndroid Build Coastguard Worker
2026*da0073e9SAndroid Build Coastguard Worker        # Tensor methods to test with sample args (tuple)
2027*da0073e9SAndroid Build Coastguard Worker        tensor_with_args = (
2028*da0073e9SAndroid Build Coastguard Worker            ("bernoulli_", (input.clone(),)),
2029*da0073e9SAndroid Build Coastguard Worker            ("cauchy_", ()),
2030*da0073e9SAndroid Build Coastguard Worker            ("exponential_", ()),
2031*da0073e9SAndroid Build Coastguard Worker            ("geometric_", (0.3,)),
2032*da0073e9SAndroid Build Coastguard Worker            ("log_normal_", ()),
2033*da0073e9SAndroid Build Coastguard Worker            ("normal_", ()),
2034*da0073e9SAndroid Build Coastguard Worker            ("random_", ()),
2035*da0073e9SAndroid Build Coastguard Worker            ("uniform_", ()),
2036*da0073e9SAndroid Build Coastguard Worker        )
2037*da0073e9SAndroid Build Coastguard Worker
2038*da0073e9SAndroid Build Coastguard Worker        def run(module, op, args, kwargs):
2039*da0073e9SAndroid Build Coastguard Worker            torch.cuda.manual_seed(5)
2040*da0073e9SAndroid Build Coastguard Worker
2041*da0073e9SAndroid Build Coastguard Worker            # Each path runs a dummy op to increment the state a bit before creating controls.
2042*da0073e9SAndroid Build Coastguard Worker            if module == "torch":
2043*da0073e9SAndroid Build Coastguard Worker                dummy = getattr(torch, op)(*args, **kwargs)
2044*da0073e9SAndroid Build Coastguard Worker                control1 = getattr(torch, op)(*args, **kwargs)
2045*da0073e9SAndroid Build Coastguard Worker                control2 = getattr(torch, op)(*args, **kwargs)
2046*da0073e9SAndroid Build Coastguard Worker            else:
2047*da0073e9SAndroid Build Coastguard Worker                dummy = alloc.clone()
2048*da0073e9SAndroid Build Coastguard Worker                control1 = alloc.clone()
2049*da0073e9SAndroid Build Coastguard Worker                control2 = alloc.clone()
2050*da0073e9SAndroid Build Coastguard Worker                getattr(dummy, op)(*args)
2051*da0073e9SAndroid Build Coastguard Worker                getattr(control1, op)(*args)
2052*da0073e9SAndroid Build Coastguard Worker                getattr(control2, op)(*args)
2053*da0073e9SAndroid Build Coastguard Worker
2054*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.Stream()
2055*da0073e9SAndroid Build Coastguard Worker            stream.wait_stream(torch.cuda.current_stream())
2056*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(stream):
2057*da0073e9SAndroid Build Coastguard Worker                torch.cuda.manual_seed(5)
2058*da0073e9SAndroid Build Coastguard Worker
2059*da0073e9SAndroid Build Coastguard Worker                g = torch.cuda.CUDAGraph()
2060*da0073e9SAndroid Build Coastguard Worker                torch.cuda.empty_cache()
2061*da0073e9SAndroid Build Coastguard Worker                if module == "torch":
2062*da0073e9SAndroid Build Coastguard Worker                    g.capture_begin()
2063*da0073e9SAndroid Build Coastguard Worker                    t1 = getattr(torch, op)(*args, **kwargs)
2064*da0073e9SAndroid Build Coastguard Worker                    t2 = getattr(torch, op)(*args, **kwargs)
2065*da0073e9SAndroid Build Coastguard Worker                    g.capture_end()
2066*da0073e9SAndroid Build Coastguard Worker                else:
2067*da0073e9SAndroid Build Coastguard Worker                    t1 = alloc.clone()
2068*da0073e9SAndroid Build Coastguard Worker                    t2 = alloc.clone()
2069*da0073e9SAndroid Build Coastguard Worker                    g.capture_begin()
2070*da0073e9SAndroid Build Coastguard Worker                    getattr(t1, op)(*args)
2071*da0073e9SAndroid Build Coastguard Worker                    getattr(t2, op)(*args)
2072*da0073e9SAndroid Build Coastguard Worker                    g.capture_end()
2073*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(stream)
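            # For reference: torch.cuda.graph() wraps roughly this manual side-stream capture
            # pattern. A minimal sketch of the torch-module branch in context-manager form
            # (illustrative only; the raw capture API is what is under test here):
            #
            #     g = torch.cuda.CUDAGraph()
            #     with torch.cuda.graph(g):
            #         t1 = getattr(torch, op)(*args, **kwargs)
            #         t2 = getattr(torch, op)(*args, **kwargs)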
2074*da0073e9SAndroid Build Coastguard Worker
2075*da0073e9SAndroid Build Coastguard Worker            if not TEST_CUDAMALLOCASYNC:
2076*da0073e9SAndroid Build Coastguard Worker                # Makes sure values haven't been populated yet
2077*da0073e9SAndroid Build Coastguard Worker                # (in other words, makes sure capture didn't actually run ops).
2078*da0073e9SAndroid Build Coastguard Worker                # We can only try this with the native allocator, for which captured
2079*da0073e9SAndroid Build Coastguard Worker                # addresses are already backed by cudaMalloced memory.
2080*da0073e9SAndroid Build Coastguard Worker                # If we try it with cudaMallocAsync, CUDA won't even consider
2081*da0073e9SAndroid Build Coastguard Worker                # the captured addresses allocated until replay(), and if we
2082*da0073e9SAndroid Build Coastguard Worker                # access them before replay() we get IMAs.
2083*da0073e9SAndroid Build Coastguard Worker                try:
2084*da0073e9SAndroid Build Coastguard Worker                    self.assertNotEqual(control1, t1)
2085*da0073e9SAndroid Build Coastguard Worker                    self.assertNotEqual(control2, t2)
2086*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
2087*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("Failed on " + module + "." + op) from e
2088*da0073e9SAndroid Build Coastguard Worker
2089*da0073e9SAndroid Build Coastguard Worker            # Set a new seed to check whether the graph picks it up
2090*da0073e9SAndroid Build Coastguard Worker            for seed in [6, 314, 271]:
2091*da0073e9SAndroid Build Coastguard Worker                torch.cuda.manual_seed(seed)
2092*da0073e9SAndroid Build Coastguard Worker                # Runs a dummy op prelude, as for controls, to make sure replay()
2093*da0073e9SAndroid Build Coastguard Worker                # picks up the dummy op's state increment.
2094*da0073e9SAndroid Build Coastguard Worker                if module == "torch":
2095*da0073e9SAndroid Build Coastguard Worker                    dummy = getattr(torch, op)(*args, **kwargs)
2096*da0073e9SAndroid Build Coastguard Worker                    control1 = getattr(torch, op)(*args, **kwargs)
2097*da0073e9SAndroid Build Coastguard Worker                    control2 = getattr(torch, op)(*args, **kwargs)
2098*da0073e9SAndroid Build Coastguard Worker                else:
2099*da0073e9SAndroid Build Coastguard Worker                    getattr(dummy, op)(*args)
2100*da0073e9SAndroid Build Coastguard Worker                    getattr(control1, op)(*args)
2101*da0073e9SAndroid Build Coastguard Worker                    getattr(control2, op)(*args)
2102*da0073e9SAndroid Build Coastguard Worker
2103*da0073e9SAndroid Build Coastguard Worker                torch.cuda.manual_seed(seed)
2104*da0073e9SAndroid Build Coastguard Worker                if module == "torch":
2105*da0073e9SAndroid Build Coastguard Worker                    dummy = getattr(torch, op)(*args, **kwargs)
2106*da0073e9SAndroid Build Coastguard Worker                else:
2107*da0073e9SAndroid Build Coastguard Worker                    getattr(dummy, op)(*args)
2108*da0073e9SAndroid Build Coastguard Worker
2109*da0073e9SAndroid Build Coastguard Worker                # see above comment on TEST_CUDAMALLOCASYNC
2110*da0073e9SAndroid Build Coastguard Worker                if not TEST_CUDAMALLOCASYNC:
2111*da0073e9SAndroid Build Coastguard Worker                    t1.copy_(alloc)
2112*da0073e9SAndroid Build Coastguard Worker                    t2.copy_(alloc)
2113*da0073e9SAndroid Build Coastguard Worker
2114*da0073e9SAndroid Build Coastguard Worker                # Runs RNG ops that fill t1 and t2.
2115*da0073e9SAndroid Build Coastguard Worker                g.replay()
2116*da0073e9SAndroid Build Coastguard Worker
2117*da0073e9SAndroid Build Coastguard Worker                try:
2118*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(control1, t1)
2119*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(control2, t2)
2120*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
2121*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("Failed on " + module + "." + op) from e
2122*da0073e9SAndroid Build Coastguard Worker
2123*da0073e9SAndroid Build Coastguard Worker            # We hold references to all tensors used across streams up until this sync,
2124*da0073e9SAndroid Build Coastguard Worker            # so no need to call record_stream on those tensors.
2125*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
2126*da0073e9SAndroid Build Coastguard Worker
2127*da0073e9SAndroid Build Coastguard Worker        for op_with_args in torch_with_args:
2128*da0073e9SAndroid Build Coastguard Worker            run("torch", *op_with_args)
2129*da0073e9SAndroid Build Coastguard Worker
2130*da0073e9SAndroid Build Coastguard Worker        for meth_with_args in tensor_with_args:
2131*da0073e9SAndroid Build Coastguard Worker            # Adds an empty dict for kwargs, which none of the Tensor methods use
2132*da0073e9SAndroid Build Coastguard Worker            run("Tensor", *(meth_with_args + ({},)))
2133*da0073e9SAndroid Build Coastguard Worker
2134*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2135*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2136*da0073e9SAndroid Build Coastguard Worker    )
2137*da0073e9SAndroid Build Coastguard Worker    def test_graph_two_successive(self):
2138*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
2139*da0073e9SAndroid Build Coastguard Worker
2140*da0073e9SAndroid Build Coastguard Worker        size = 1000
2141*da0073e9SAndroid Build Coastguard Worker        kSmallBuffer = 2097152
2142*da0073e9SAndroid Build Coastguard Worker
2143*da0073e9SAndroid Build Coastguard Worker        def func_with_temps(t, val):
2144*da0073e9SAndroid Build Coastguard Worker            x = t.clone() + val
2145*da0073e9SAndroid Build Coastguard Worker            y = t.clone() + val
2146*da0073e9SAndroid Build Coastguard Worker            return x + y
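        # Each call allocates two temporaries and returns 2 * t + 2 * val.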
2147*da0073e9SAndroid Build Coastguard Worker
2148*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
2149*da0073e9SAndroid Build Coastguard Worker
2150*da0073e9SAndroid Build Coastguard Worker        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
2151*da0073e9SAndroid Build Coastguard Worker            g0 = torch.cuda.CUDAGraph()
2152*da0073e9SAndroid Build Coastguard Worker            g1 = torch.cuda.CUDAGraph()
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker            a = torch.ones((size,), device="cuda")
2155*da0073e9SAndroid Build Coastguard Worker
2156*da0073e9SAndroid Build Coastguard Worker            s.wait_stream(torch.cuda.current_stream())
2157*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
2158*da0073e9SAndroid Build Coastguard Worker                g0_args = (
2159*da0073e9SAndroid Build Coastguard Worker                    (torch.cuda.graph_pool_handle(),)
2160*da0073e9SAndroid Build Coastguard Worker                    if share_mem == "via graph_pool_handle()"
2161*da0073e9SAndroid Build Coastguard Worker                    else ()
2162*da0073e9SAndroid Build Coastguard Worker                )
2163*da0073e9SAndroid Build Coastguard Worker                g0.capture_begin(*g0_args)
2164*da0073e9SAndroid Build Coastguard Worker                b = a.clone()
2165*da0073e9SAndroid Build Coastguard Worker                for _ in range(5):
2166*da0073e9SAndroid Build Coastguard Worker                    b = func_with_temps(b, 1)
2167*da0073e9SAndroid Build Coastguard Worker                g0.capture_end()
2168*da0073e9SAndroid Build Coastguard Worker
2169*da0073e9SAndroid Build Coastguard Worker                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
2170*da0073e9SAndroid Build Coastguard Worker                g1.capture_begin(*g1_args)
2171*da0073e9SAndroid Build Coastguard Worker                for _ in range(5):
2172*da0073e9SAndroid Build Coastguard Worker                    b = func_with_temps(b, 1)
2173*da0073e9SAndroid Build Coastguard Worker                g1.capture_end()
2174*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(s)
2175*da0073e9SAndroid Build Coastguard Worker
2176*da0073e9SAndroid Build Coastguard Worker            # mixes unrelated eager ops with replays
2177*da0073e9SAndroid Build Coastguard Worker            c = a.clone()
2178*da0073e9SAndroid Build Coastguard Worker            for _ in range(2):
2179*da0073e9SAndroid Build Coastguard Worker                c = func_with_temps(c, 3)
2180*da0073e9SAndroid Build Coastguard Worker            g0.replay()
2181*da0073e9SAndroid Build Coastguard Worker            for _ in range(2):
2182*da0073e9SAndroid Build Coastguard Worker                c = func_with_temps(c, 3)
2183*da0073e9SAndroid Build Coastguard Worker            g1.replay()
2184*da0073e9SAndroid Build Coastguard Worker            for _ in range(2):
2185*da0073e9SAndroid Build Coastguard Worker                c = func_with_temps(c, 3)
2186*da0073e9SAndroid Build Coastguard Worker
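            # Expected values: b started at 1 and took ten func_with_temps steps with val=1
            # (five captured in g0, five in g1): 1 -> 4 -> 10 -> 22 -> 46 -> 94 -> 190 -> 382 -> 766 -> 1534 -> 3070.
            # c started at 1 and took six eager steps with val=3: 1 -> 8 -> 22 -> 50 -> 106 -> 218 -> 442.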
2187*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b.sum().item(), size * 3070)
2188*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(c.sum().item(), size * 442)
2189*da0073e9SAndroid Build Coastguard Worker
2190*da0073e9SAndroid Build Coastguard Worker            if not TEST_CUDAMALLOCASYNC:
2191*da0073e9SAndroid Build Coastguard Worker                # These stat checks are specific to the native allocator.
2192*da0073e9SAndroid Build Coastguard Worker                if share_mem != "Don't share":
2193*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
2194*da0073e9SAndroid Build Coastguard Worker                        reserved_no_sharing  # noqa: F821
2195*da0073e9SAndroid Build Coastguard Worker                        - torch.cuda.memory_stats()["reserved_bytes.all.current"],
2196*da0073e9SAndroid Build Coastguard Worker                        kSmallBuffer,
2197*da0073e9SAndroid Build Coastguard Worker                    )
2198*da0073e9SAndroid Build Coastguard Worker                else:
2199*da0073e9SAndroid Build Coastguard Worker                    reserved_no_sharing = torch.cuda.memory_stats()[
2200*da0073e9SAndroid Build Coastguard Worker                        "reserved_bytes.all.current"
2201*da0073e9SAndroid Build Coastguard Worker                    ]
2202*da0073e9SAndroid Build Coastguard Worker
2203*da0073e9SAndroid Build Coastguard Worker            del a, b, c, g0, g1
2204*da0073e9SAndroid Build Coastguard Worker            # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
2205*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
2206*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
2207*da0073e9SAndroid Build Coastguard Worker
2208*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2209*da0073e9SAndroid Build Coastguard Worker        (not TEST_CUDA_GRAPH)
2210*da0073e9SAndroid Build Coastguard Worker        or IS_WINDOWS
2211*da0073e9SAndroid Build Coastguard Worker        or (  # appears to still be broken on Windows as of 11.4+
2212*da0073e9SAndroid Build Coastguard Worker            torch.version.cuda
2213*da0073e9SAndroid Build Coastguard Worker            and int(torch.version.cuda.split(".")[0]) == 11
2214*da0073e9SAndroid Build Coastguard Worker            and int(torch.version.cuda.split(".")[1]) < 4
2215*da0073e9SAndroid Build Coastguard Worker        ),
2216*da0073e9SAndroid Build Coastguard Worker        "Graph bindings disallow concurrent replay for CUDA < 11.4, see "
2217*da0073e9SAndroid Build Coastguard Worker        + "https://github.com/pytorch/pytorch/pull/57556",
2218*da0073e9SAndroid Build Coastguard Worker    )
2219*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2220*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2221*da0073e9SAndroid Build Coastguard Worker    )
2222*da0073e9SAndroid Build Coastguard Worker    def test_graph_concurrent_replay(self):
2223*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
2224*da0073e9SAndroid Build Coastguard Worker
2225*da0073e9SAndroid Build Coastguard Worker        size = 1000000  # largeish to help expose race conditions
2226*da0073e9SAndroid Build Coastguard Worker
2227*da0073e9SAndroid Build Coastguard Worker        def func_with_temps(t, val):
2228*da0073e9SAndroid Build Coastguard Worker            x = t.clone() + val
2229*da0073e9SAndroid Build Coastguard Worker            y = t.clone() + val
2230*da0073e9SAndroid Build Coastguard Worker            return x + y
2231*da0073e9SAndroid Build Coastguard Worker
2232*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
2235*da0073e9SAndroid Build Coastguard Worker            g0 = torch.cuda.CUDAGraph()
2236*da0073e9SAndroid Build Coastguard Worker            g1 = torch.cuda.CUDAGraph()
2237*da0073e9SAndroid Build Coastguard Worker
2238*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.Stream()
2239*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.Stream()
2240*da0073e9SAndroid Build Coastguard Worker
2241*da0073e9SAndroid Build Coastguard Worker            a = torch.ones((size,), device="cuda")
2242*da0073e9SAndroid Build Coastguard Worker
2243*da0073e9SAndroid Build Coastguard Worker            s.wait_stream(torch.cuda.current_stream())
2244*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
2245*da0073e9SAndroid Build Coastguard Worker                g0_args = (
2246*da0073e9SAndroid Build Coastguard Worker                    (torch.cuda.graph_pool_handle(),)
2247*da0073e9SAndroid Build Coastguard Worker                    if share_mem == "via graph_pool_handle()"
2248*da0073e9SAndroid Build Coastguard Worker                    else ()
2249*da0073e9SAndroid Build Coastguard Worker                )
2250*da0073e9SAndroid Build Coastguard Worker                g0.capture_begin(*g0_args)
2251*da0073e9SAndroid Build Coastguard Worker                b = a.clone()
2252*da0073e9SAndroid Build Coastguard Worker                for _ in range(5):
2253*da0073e9SAndroid Build Coastguard Worker                    b = func_with_temps(b, 1)
2254*da0073e9SAndroid Build Coastguard Worker                g0.capture_end()
2255*da0073e9SAndroid Build Coastguard Worker
2256*da0073e9SAndroid Build Coastguard Worker                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
2257*da0073e9SAndroid Build Coastguard Worker                g1.capture_begin(*g1_args)
2258*da0073e9SAndroid Build Coastguard Worker                c = a.clone()
2259*da0073e9SAndroid Build Coastguard Worker                for _ in range(5):
2260*da0073e9SAndroid Build Coastguard Worker                    c = func_with_temps(c, 2)
2261*da0073e9SAndroid Build Coastguard Worker                g1.capture_end()
2262*da0073e9SAndroid Build Coastguard Worker
2263*da0073e9SAndroid Build Coastguard Worker            # To reproduce data corruption, I need g0 and g1's kernels to run concurrently.
2264*da0073e9SAndroid Build Coastguard Worker            # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead.
2265*da0073e9SAndroid Build Coastguard Worker            # The following pattern helps align device-side execution of g0 and g1's kernels.
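            # The sleep kernel keeps the device busy while the host enqueues both replays;
            # s1 only waits for the sleep, so g0 (on s0) and g1 (on s1) start together once it finishes.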
2266*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
2267*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s0):
2268*da0073e9SAndroid Build Coastguard Worker                torch.cuda._sleep(1000000)
2269*da0073e9SAndroid Build Coastguard Worker                s1.wait_stream(s0)
2270*da0073e9SAndroid Build Coastguard Worker                g0.replay()
2271*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s1):
2272*da0073e9SAndroid Build Coastguard Worker                g1.replay()
2273*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(s0)
2274*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(s1)
2275*da0073e9SAndroid Build Coastguard Worker
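            # Uncorrupted expectations (each func_with_temps(t, val) step computes 2 * t + 2 * val):
            # b took five val=1 steps from 1: 1 -> 4 -> 10 -> 22 -> 46 -> 94.
            # c took five val=2 steps from 1: 1 -> 6 -> 16 -> 36 -> 76 -> 156.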
2276*da0073e9SAndroid Build Coastguard Worker            if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"):
2277*da0073e9SAndroid Build Coastguard Worker                # If we used the native allocator and shared mempools,
2278*da0073e9SAndroid Build Coastguard Worker                # we expect the concurrent replays corrupted each other.
2279*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(b.sum().item(), size * 94)
2280*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(c.sum().item(), size * 156)
2281*da0073e9SAndroid Build Coastguard Worker            else:
2282*da0073e9SAndroid Build Coastguard Worker                # If we EITHER
2283*da0073e9SAndroid Build Coastguard Worker                #   - used the native allocator without sharing mempools, OR
2284*da0073e9SAndroid Build Coastguard Worker                #   - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe
2285*da0073e9SAndroid Build Coastguard Worker                # we don't expect memory corruption.
2286*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.sum().item(), size * 94)
2287*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c.sum().item(), size * 156)
2288*da0073e9SAndroid Build Coastguard Worker
2289*da0073e9SAndroid Build Coastguard Worker            del a, b, c, g0, g1
2290*da0073e9SAndroid Build Coastguard Worker            # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them.
2291*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
2292*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
2293*da0073e9SAndroid Build Coastguard Worker
2294*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2295*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2296*da0073e9SAndroid Build Coastguard Worker    )
2297*da0073e9SAndroid Build Coastguard Worker    def test_graph_three_successive(self):
2298*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
2299*da0073e9SAndroid Build Coastguard Worker
2300*da0073e9SAndroid Build Coastguard Worker        size = 1000
2301*da0073e9SAndroid Build Coastguard Worker
2302*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
2303*da0073e9SAndroid Build Coastguard Worker
2304*da0073e9SAndroid Build Coastguard Worker        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
2305*da0073e9SAndroid Build Coastguard Worker            a = torch.ones((size,), device="cuda")
2306*da0073e9SAndroid Build Coastguard Worker
2307*da0073e9SAndroid Build Coastguard Worker            g0 = torch.cuda.CUDAGraph()
2308*da0073e9SAndroid Build Coastguard Worker            g1 = torch.cuda.CUDAGraph()
2309*da0073e9SAndroid Build Coastguard Worker            g2 = torch.cuda.CUDAGraph()
2310*da0073e9SAndroid Build Coastguard Worker
2311*da0073e9SAndroid Build Coastguard Worker            s.wait_stream(torch.cuda.current_stream())
2312*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
2313*da0073e9SAndroid Build Coastguard Worker                g0_args = (
2314*da0073e9SAndroid Build Coastguard Worker                    (torch.cuda.graph_pool_handle(),)
2315*da0073e9SAndroid Build Coastguard Worker                    if share_mem == "via graph_pool_handle()"
2316*da0073e9SAndroid Build Coastguard Worker                    else ()
2317*da0073e9SAndroid Build Coastguard Worker                )
2318*da0073e9SAndroid Build Coastguard Worker                g0.capture_begin(*g0_args)
2319*da0073e9SAndroid Build Coastguard Worker                b = a.clone()
2320*da0073e9SAndroid Build Coastguard Worker                c = b + 1
2321*da0073e9SAndroid Build Coastguard Worker                d = b + 2
2322*da0073e9SAndroid Build Coastguard Worker                g0.capture_end()
2323*da0073e9SAndroid Build Coastguard Worker
2324*da0073e9SAndroid Build Coastguard Worker                args = (g0.pool(),) if share_mem == "via pool()" else g0_args
2325*da0073e9SAndroid Build Coastguard Worker
2326*da0073e9SAndroid Build Coastguard Worker                g1.capture_begin(*args)
2327*da0073e9SAndroid Build Coastguard Worker                e = c + 3
2328*da0073e9SAndroid Build Coastguard Worker                del c
2329*da0073e9SAndroid Build Coastguard Worker                g1.capture_end()
2330*da0073e9SAndroid Build Coastguard Worker
2331*da0073e9SAndroid Build Coastguard Worker                g2.capture_begin(*args)
2332*da0073e9SAndroid Build Coastguard Worker                f = d + 4
2333*da0073e9SAndroid Build Coastguard Worker                g2.capture_end()
2334*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(s)
2335*da0073e9SAndroid Build Coastguard Worker
2336*da0073e9SAndroid Build Coastguard Worker            # Tests that replaying in capture order is valid
2337*da0073e9SAndroid Build Coastguard Worker            g0.replay()
2338*da0073e9SAndroid Build Coastguard Worker            g1.replay()
2339*da0073e9SAndroid Build Coastguard Worker            g2.replay()
2340*da0073e9SAndroid Build Coastguard Worker
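            # With all-ones input: b = 1, c = b + 1 = 2, d = b + 2 = 3, e = c + 3 = 5, f = d + 4 = 7,
            # so e.sum() should be size * 5 and f.sum() should be size * 7.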
2341*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(e.sum().item(), size * 5)
2342*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(f.sum().item(), size * 7)
2343*da0073e9SAndroid Build Coastguard Worker
2344*da0073e9SAndroid Build Coastguard Worker            # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool
2345*da0073e9SAndroid Build Coastguard Worker            g0.replay()
2346*da0073e9SAndroid Build Coastguard Worker            g2.replay()
2347*da0073e9SAndroid Build Coastguard Worker            g1.replay()
2348*da0073e9SAndroid Build Coastguard Worker
2349*da0073e9SAndroid Build Coastguard Worker            expect_corruption = (not TEST_CUDAMALLOCASYNC) and (
2350*da0073e9SAndroid Build Coastguard Worker                share_mem != "Don't share"
2351*da0073e9SAndroid Build Coastguard Worker            )
2352*da0073e9SAndroid Build Coastguard Worker            # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f.
2353*da0073e9SAndroid Build Coastguard Worker            # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
2354*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2355*da0073e9SAndroid Build Coastguard Worker                e.sum().item(), size * (7 + 3) if expect_corruption else size * 5
2356*da0073e9SAndroid Build Coastguard Worker            )
2357*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(f.sum().item(), size * 7)
2358*da0073e9SAndroid Build Coastguard Worker
2359*da0073e9SAndroid Build Coastguard Worker            del a, b, d, e, f, g0, g1, g2
2360*da0073e9SAndroid Build Coastguard Worker            # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them.
2361*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
2362*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
2363*da0073e9SAndroid Build Coastguard Worker
2364*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2365*da0073e9SAndroid Build Coastguard Worker        (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC,
2366*da0073e9SAndroid Build Coastguard Worker        "CUDA >= 11.0 or ROCM >= 5.3 required for graphs; stats checked are specific to the native allocator",
2367*da0073e9SAndroid Build Coastguard Worker    )
2368*da0073e9SAndroid Build Coastguard Worker    def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
2369*da0073e9SAndroid Build Coastguard Worker        kSmallSize = 1048576
2370*da0073e9SAndroid Build Coastguard Worker        kSmallBuffer = 2097152
2371*da0073e9SAndroid Build Coastguard Worker        kLargeBuffer = 20971520
2372*da0073e9SAndroid Build Coastguard Worker        kMinLargeAlloc = 10485760
2373*da0073e9SAndroid Build Coastguard Worker        kRoundLarge = 2097152
2374*da0073e9SAndroid Build Coastguard Worker
2375*da0073e9SAndroid Build Coastguard Worker        elem = 4
2376*da0073e9SAndroid Build Coastguard Worker
2377*da0073e9SAndroid Build Coastguard Worker        # this was annoying to write but stresses the expectations pretty rigorously
2378*da0073e9SAndroid Build Coastguard Worker        cases = (
2379*da0073e9SAndroid Build Coastguard Worker            (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"),
2380*da0073e9SAndroid Build Coastguard Worker            (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"),
2381*da0073e9SAndroid Build Coastguard Worker            ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"),
2382*da0073e9SAndroid Build Coastguard Worker            (
2383*da0073e9SAndroid Build Coastguard Worker                (kMinLargeAlloc - 512) // elem,
2384*da0073e9SAndroid Build Coastguard Worker                2,
2385*da0073e9SAndroid Build Coastguard Worker                2 * kLargeBuffer,
2386*da0073e9SAndroid Build Coastguard Worker                kLargeBuffer,
2387*da0073e9SAndroid Build Coastguard Worker                "large_pool",
2388*da0073e9SAndroid Build Coastguard Worker            ),
2389*da0073e9SAndroid Build Coastguard Worker            (
2390*da0073e9SAndroid Build Coastguard Worker                (kMinLargeAlloc + 512) // elem,
2391*da0073e9SAndroid Build Coastguard Worker                3,
2392*da0073e9SAndroid Build Coastguard Worker                3
2393*da0073e9SAndroid Build Coastguard Worker                * (
2394*da0073e9SAndroid Build Coastguard Worker                    kRoundLarge
2395*da0073e9SAndroid Build Coastguard Worker                    * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge)
2396*da0073e9SAndroid Build Coastguard Worker                ),
2397*da0073e9SAndroid Build Coastguard Worker                kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge),
2398*da0073e9SAndroid Build Coastguard Worker                "large_pool",
2399*da0073e9SAndroid Build Coastguard Worker            ),
2400*da0073e9SAndroid Build Coastguard Worker        )
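        # Each case reads as (numel, expected segments added by capture, expected reserved bytes
        # added by capture, expected reserved bytes still held after the graph is deleted while b
        # survives, which pool the allocation lands in).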
2401*da0073e9SAndroid Build Coastguard Worker
2402*da0073e9SAndroid Build Coastguard Worker        stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.")
2403*da0073e9SAndroid Build Coastguard Worker
2404*da0073e9SAndroid Build Coastguard Worker        gc.collect()
2405*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
2406*da0073e9SAndroid Build Coastguard Worker
2407*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
2408*da0073e9SAndroid Build Coastguard Worker
2409*da0073e9SAndroid Build Coastguard Worker        for (
2410*da0073e9SAndroid Build Coastguard Worker            numel,
2411*da0073e9SAndroid Build Coastguard Worker            delta_cudaMallocs,
2412*da0073e9SAndroid Build Coastguard Worker            delta_cudaMalloc_bytes,
2413*da0073e9SAndroid Build Coastguard Worker            delta_cudaMalloc_bytes_post_del_g,
2414*da0073e9SAndroid Build Coastguard Worker            pool_string,
2415*da0073e9SAndroid Build Coastguard Worker        ) in cases:
2416*da0073e9SAndroid Build Coastguard Worker            if pool_string == "small_pool":
2417*da0073e9SAndroid Build Coastguard Worker                delta_active_blocks = 3  # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders
2418*da0073e9SAndroid Build Coastguard Worker                delta_active_bytes = (
2419*da0073e9SAndroid Build Coastguard Worker                    numel * elem + 1024
2420*da0073e9SAndroid Build Coastguard Worker                )  # + 1024 total for CUDAGraph's rng seed and offset holders
2421*da0073e9SAndroid Build Coastguard Worker            else:
2422*da0073e9SAndroid Build Coastguard Worker                delta_active_blocks = 1  # We only check the large pool, which isn't affected by rng offset holder
2423*da0073e9SAndroid Build Coastguard Worker                delta_active_bytes = numel * elem
2424*da0073e9SAndroid Build Coastguard Worker
2425*da0073e9SAndroid Build Coastguard Worker            g = torch.cuda.CUDAGraph()
2426*da0073e9SAndroid Build Coastguard Worker            s.wait_stream(torch.cuda.current_stream())
2427*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
2428*da0073e9SAndroid Build Coastguard Worker                # Allocation stat estimates assume input is created on the same stream as capture_begin()
2429*da0073e9SAndroid Build Coastguard Worker                # (in other words, the same stream silo as the rng offset holder, which is not allocated from the
2430*da0073e9SAndroid Build Coastguard Worker                # capture's private pool).
2431*da0073e9SAndroid Build Coastguard Worker                a = torch.ones((numel,), device="cuda")
2432*da0073e9SAndroid Build Coastguard Worker
2433*da0073e9SAndroid Build Coastguard Worker                precapture_stats = torch.cuda.memory_stats()
2434*da0073e9SAndroid Build Coastguard Worker
2435*da0073e9SAndroid Build Coastguard Worker                g.capture_begin()
2436*da0073e9SAndroid Build Coastguard Worker                b = a.clone()
2437*da0073e9SAndroid Build Coastguard Worker                for _ in range(5):
2438*da0073e9SAndroid Build Coastguard Worker                    b = b.clone() + 1
2439*da0073e9SAndroid Build Coastguard Worker                g.capture_end()
2440*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(s)
2441*da0073e9SAndroid Build Coastguard Worker
2442*da0073e9SAndroid Build Coastguard Worker            gc.collect()
2443*da0073e9SAndroid Build Coastguard Worker
2444*da0073e9SAndroid Build Coastguard Worker            postcapture_stats = torch.cuda.memory_stats()
2445*da0073e9SAndroid Build Coastguard Worker
2446*da0073e9SAndroid Build Coastguard Worker            expecteds = (
2447*da0073e9SAndroid Build Coastguard Worker                delta_cudaMallocs,
2448*da0073e9SAndroid Build Coastguard Worker                delta_cudaMalloc_bytes,
2449*da0073e9SAndroid Build Coastguard Worker                delta_active_blocks,
2450*da0073e9SAndroid Build Coastguard Worker                delta_active_bytes,
2451*da0073e9SAndroid Build Coastguard Worker            )
2452*da0073e9SAndroid Build Coastguard Worker            # Double checks replay and stats before and after a call to empty_cache
2453*da0073e9SAndroid Build Coastguard Worker            for i in range(2):
2454*da0073e9SAndroid Build Coastguard Worker                for stat, expected in zip(stats_to_check, expecteds):
2455*da0073e9SAndroid Build Coastguard Worker                    stat = stat + pool_string + ".current"
2456*da0073e9SAndroid Build Coastguard Worker                    current = postcapture_stats[stat] - precapture_stats[stat]
2457*da0073e9SAndroid Build Coastguard Worker
2458*da0073e9SAndroid Build Coastguard Worker                    # There will only ever be one expandable segment in each of the small and large pools. The way the
2459*da0073e9SAndroid Build Coastguard Worker                    # bookkeeping is done in the allocator means that we never increment the number of segments.
2460*da0073e9SAndroid Build Coastguard Worker                    if self.expandable_segments and "segment" in stat:
2461*da0073e9SAndroid Build Coastguard Worker                        expected = 0
2462*da0073e9SAndroid Build Coastguard Worker                    # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
2463*da0073e9SAndroid Build Coastguard Worker                    # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
2464*da0073e9SAndroid Build Coastguard Worker                    # smaller than the page size
2465*da0073e9SAndroid Build Coastguard Worker                    if (
2466*da0073e9SAndroid Build Coastguard Worker                        self.expandable_segments
2467*da0073e9SAndroid Build Coastguard Worker                        and "reserved" in stat
2468*da0073e9SAndroid Build Coastguard Worker                        and (numel == cases[3][0] or numel == cases[4][0])
2469*da0073e9SAndroid Build Coastguard Worker                    ):
2470*da0073e9SAndroid Build Coastguard Worker                        expected = 2 * kLargeBuffer
2471*da0073e9SAndroid Build Coastguard Worker
2472*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
2473*da0073e9SAndroid Build Coastguard Worker                        current,
2474*da0073e9SAndroid Build Coastguard Worker                        expected,
2475*da0073e9SAndroid Build Coastguard Worker                        "Pre to post capture delta of "
2476*da0073e9SAndroid Build Coastguard Worker                        + stat
2477*da0073e9SAndroid Build Coastguard Worker                        + f" = {current}, expected = {expected}, numel = {numel}",
2478*da0073e9SAndroid Build Coastguard Worker                    )
2479*da0073e9SAndroid Build Coastguard Worker
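                # b starts as a clone of all ones and the captured region adds 1 five times,
                # so after replay every element of b should be 6.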
2480*da0073e9SAndroid Build Coastguard Worker                g.replay()
2481*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.sum().item(), 6 * numel)
2482*da0073e9SAndroid Build Coastguard Worker                if i == 0:
2483*da0073e9SAndroid Build Coastguard Worker                    torch.cuda.empty_cache()
2484*da0073e9SAndroid Build Coastguard Worker
2485*da0073e9SAndroid Build Coastguard Worker            del g
2486*da0073e9SAndroid Build Coastguard Worker            gc.collect()
2487*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
2488*da0073e9SAndroid Build Coastguard Worker            postdel_stats = torch.cuda.memory_stats()
2489*da0073e9SAndroid Build Coastguard Worker
2490*da0073e9SAndroid Build Coastguard Worker            # Uses graph result b after graph has been deleted
2491*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b.sum().item(), 6 * numel)
2492*da0073e9SAndroid Build Coastguard Worker
2493*da0073e9SAndroid Build Coastguard Worker            # b should be the only live reference remaining from the graph's private pool
2494*da0073e9SAndroid Build Coastguard Worker            expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem)
2495*da0073e9SAndroid Build Coastguard Worker            for stat, expected in zip(stats_to_check, expecteds):
2496*da0073e9SAndroid Build Coastguard Worker                stat = stat + pool_string + ".current"
2497*da0073e9SAndroid Build Coastguard Worker                current = postdel_stats[stat] - precapture_stats[stat]
2498*da0073e9SAndroid Build Coastguard Worker
2499*da0073e9SAndroid Build Coastguard Worker                # There will only ever be one expandable segment in each of the small and large pools. The way the
2500*da0073e9SAndroid Build Coastguard Worker                # bookkeeping is done in the allocator means that we never increment the number of segments.
2501*da0073e9SAndroid Build Coastguard Worker                if self.expandable_segments and "segment" in stat:
2502*da0073e9SAndroid Build Coastguard Worker                    expected = 0
2503*da0073e9SAndroid Build Coastguard Worker                # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
2504*da0073e9SAndroid Build Coastguard Worker                # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
2505*da0073e9SAndroid Build Coastguard Worker                # smaller than the page size
2506*da0073e9SAndroid Build Coastguard Worker                if (
2507*da0073e9SAndroid Build Coastguard Worker                    self.expandable_segments
2508*da0073e9SAndroid Build Coastguard Worker                    and "reserved" in stat
2509*da0073e9SAndroid Build Coastguard Worker                    and numel == cases[3][0]
2510*da0073e9SAndroid Build Coastguard Worker                ):
2511*da0073e9SAndroid Build Coastguard Worker                    expected = 2 * kLargeBuffer
2512*da0073e9SAndroid Build Coastguard Worker                if (
2513*da0073e9SAndroid Build Coastguard Worker                    self.expandable_segments
2514*da0073e9SAndroid Build Coastguard Worker                    and "reserved" in stat
2515*da0073e9SAndroid Build Coastguard Worker                    and numel == cases[4][0]
2516*da0073e9SAndroid Build Coastguard Worker                ):
2517*da0073e9SAndroid Build Coastguard Worker                    expected = kLargeBuffer
2518*da0073e9SAndroid Build Coastguard Worker
2519*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2520*da0073e9SAndroid Build Coastguard Worker                    current,
2521*da0073e9SAndroid Build Coastguard Worker                    expected,
2522*da0073e9SAndroid Build Coastguard Worker                    "Pre capture to post graph delete delta of "
2523*da0073e9SAndroid Build Coastguard Worker                    + stat
2524*da0073e9SAndroid Build Coastguard Worker                    + f" = {current}, expected = {expected}, numel = {numel}",
2525*da0073e9SAndroid Build Coastguard Worker                )
2526*da0073e9SAndroid Build Coastguard Worker
2527*da0073e9SAndroid Build Coastguard Worker            # del a, b before the next case is essential, otherwise overwriting a and b in the next case
2528*da0073e9SAndroid Build Coastguard Worker            # can throw off its allocation/deallocation counts.
2529*da0073e9SAndroid Build Coastguard Worker            del a, b
2530*da0073e9SAndroid Build Coastguard Worker            # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
2531*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
2532*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
2533*da0073e9SAndroid Build Coastguard Worker
2534*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2535*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2536*da0073e9SAndroid Build Coastguard Worker    )
2537*da0073e9SAndroid Build Coastguard Worker    def test_graph_record_stream(self):
2538*da0073e9SAndroid Build Coastguard Worker        # Makes sure graph capture defers attempting to reclaim allocations used across streams. See
2539*da0073e9SAndroid Build Coastguard Worker        # "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp
2540*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
2541*da0073e9SAndroid Build Coastguard Worker
2542*da0073e9SAndroid Build Coastguard Worker        potential_problem = torch.zeros((3,), device="cuda")
2543*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros((3,), device="cuda")
2544*da0073e9SAndroid Build Coastguard Worker        s0 = torch.cuda.Stream()
2545*da0073e9SAndroid Build Coastguard Worker        s1 = torch.cuda.Stream()
2546*da0073e9SAndroid Build Coastguard Worker        s2 = torch.cuda.Stream()
2547*da0073e9SAndroid Build Coastguard Worker        g = torch.cuda.CUDAGraph()
2548*da0073e9SAndroid Build Coastguard Worker
2549*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
2550*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s0):
2551*da0073e9SAndroid Build Coastguard Worker            potential_problem.record_stream(s0)
2552*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES)
2553*da0073e9SAndroid Build Coastguard Worker            potential_problem.fill_(1.0)
2554*da0073e9SAndroid Build Coastguard Worker        del potential_problem
2555*da0073e9SAndroid Build Coastguard Worker
2556*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s1):
2557*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
2558*da0073e9SAndroid Build Coastguard Worker            # potential_problem's allocation should still be outstanding. if DeviceCachingAllocator::malloc
2559*da0073e9SAndroid Build Coastguard Worker            # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life
2560*da0073e9SAndroid Build Coastguard Worker            # event, which will cause the capture to error.
2561*da0073e9SAndroid Build Coastguard Worker            b = a.clone()
2562*da0073e9SAndroid Build Coastguard Worker
2563*da0073e9SAndroid Build Coastguard Worker            # Let's also see what happens if we record_stream on a tensor during capture.
2564*da0073e9SAndroid Build Coastguard Worker            s2.wait_stream(s1)
2565*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s2):
2566*da0073e9SAndroid Build Coastguard Worker                b.fill_(1.0)
2567*da0073e9SAndroid Build Coastguard Worker                b.record_stream(s2)  # dummy record_stream
2568*da0073e9SAndroid Build Coastguard Worker                del b
2569*da0073e9SAndroid Build Coastguard Worker            s1.wait_stream(s2)
2570*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
2571*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
2572*da0073e9SAndroid Build Coastguard Worker
2573*da0073e9SAndroid Build Coastguard Worker        # A dummy allocation triggers process_events, which should successfully process b's end-of-life event.
2574*da0073e9SAndroid Build Coastguard Worker        c = torch.zeros((3,), device="cuda")
2575*da0073e9SAndroid Build Coastguard Worker
2576*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
2577*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2578*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2579*da0073e9SAndroid Build Coastguard Worker    )
2580*da0073e9SAndroid Build Coastguard Worker    # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize
2581*da0073e9SAndroid Build Coastguard Worker    # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior
2582*da0073e9SAndroid Build Coastguard Worker    # as a memory leak unless we skip the leak check.
2583*da0073e9SAndroid Build Coastguard Worker    @skipCUDAMemoryLeakCheckIf(True)
2584*da0073e9SAndroid Build Coastguard Worker    @serialTest()
2585*da0073e9SAndroid Build Coastguard Worker    def test_graph_cudnn_dropout(self):
2586*da0073e9SAndroid Build Coastguard Worker        # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp.
2587*da0073e9SAndroid Build Coastguard Worker        # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should
2588*da0073e9SAndroid Build Coastguard Worker        # avoid syncing noncapturing streams with captured events or vice versa.
2589*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
2590*da0073e9SAndroid Build Coastguard Worker
2591*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
2592*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(100, 192, 512, device="cuda")
2593*da0073e9SAndroid Build Coastguard Worker
2594*da0073e9SAndroid Build Coastguard Worker        y = model(x)
2595*da0073e9SAndroid Build Coastguard Worker
2596*da0073e9SAndroid Build Coastguard Worker        g = torch.cuda.CUDAGraph()
2597*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
2598*da0073e9SAndroid Build Coastguard Worker        s.wait_stream(torch.cuda.current_stream())
2599*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
2600*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
2601*da0073e9SAndroid Build Coastguard Worker            y = model(x)
2602*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
2603*da0073e9SAndroid Build Coastguard Worker        torch.cuda.current_stream().wait_stream(s)
2604*da0073e9SAndroid Build Coastguard Worker
2605*da0073e9SAndroid Build Coastguard Worker        g.replay()
2606*da0073e9SAndroid Build Coastguard Worker
2607*da0073e9SAndroid Build Coastguard Worker        y = model(x)
2608*da0073e9SAndroid Build Coastguard Worker
2609*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2610*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2611*da0073e9SAndroid Build Coastguard Worker    )
2612*da0073e9SAndroid Build Coastguard Worker    @parametrize(
2613*da0073e9SAndroid Build Coastguard Worker        "with_amp,cache_enabled,allow_unused_input",
2614*da0073e9SAndroid Build Coastguard Worker        [
2615*da0073e9SAndroid Build Coastguard Worker            subtest((False, False, True), decorators=[skipIfRocm]),
2616*da0073e9SAndroid Build Coastguard Worker            subtest((True, False, True), decorators=[skipIfRocm]),
2617*da0073e9SAndroid Build Coastguard Worker            subtest((True, True, True), decorators=[unittest.expectedFailure]),
2618*da0073e9SAndroid Build Coastguard Worker            subtest((False, False, False), decorators=[unittest.expectedFailure]),
2619*da0073e9SAndroid Build Coastguard Worker        ],
2620*da0073e9SAndroid Build Coastguard Worker        name_fn=lambda x, y, z: "{}{}{}".format(
2621*da0073e9SAndroid Build Coastguard Worker            {True: "with_amp", False: "without_amp"}[x],
2622*da0073e9SAndroid Build Coastguard Worker            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
2623*da0073e9SAndroid Build Coastguard Worker            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
2624*da0073e9SAndroid Build Coastguard Worker        ),
2625*da0073e9SAndroid Build Coastguard Worker    )
2626*da0073e9SAndroid Build Coastguard Worker    @serialTest()
2627*da0073e9SAndroid Build Coastguard Worker    def test_graph_make_graphed_callables(
2628*da0073e9SAndroid Build Coastguard Worker        self, with_amp, cache_enabled, allow_unused_input
2629*da0073e9SAndroid Build Coastguard Worker    ):
2630*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(5)
2631*da0073e9SAndroid Build Coastguard Worker        torch.cuda.manual_seed(5)
2632*da0073e9SAndroid Build Coastguard Worker
2633*da0073e9SAndroid Build Coastguard Worker        N, D_in, H, D_out = 640, 4096, 2048, 1024
2634*da0073e9SAndroid Build Coastguard Worker
2635*da0073e9SAndroid Build Coastguard Worker        class MLP1(torch.nn.Module):
2636*da0073e9SAndroid Build Coastguard Worker            def __init__(self, D_in: int, H: int, D_out: int):
2637*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2638*da0073e9SAndroid Build Coastguard Worker                self.net_1 = torch.nn.Sequential(
2639*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
2640*da0073e9SAndroid Build Coastguard Worker                ).cuda()
2641*da0073e9SAndroid Build Coastguard Worker                self.net_2 = torch.nn.Sequential(
2642*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
2643*da0073e9SAndroid Build Coastguard Worker                ).cuda()
2644*da0073e9SAndroid Build Coastguard Worker
2645*da0073e9SAndroid Build Coastguard Worker            def forward(self, input_dict: dict):
2646*da0073e9SAndroid Build Coastguard Worker                x = input_dict["x"]
2647*da0073e9SAndroid Build Coastguard Worker                return self.net_2(self.net_1(x))
2648*da0073e9SAndroid Build Coastguard Worker
2649*da0073e9SAndroid Build Coastguard Worker        class MLP2(torch.nn.Module):
2650*da0073e9SAndroid Build Coastguard Worker            def __init__(self, D_in: int, H: int, D_out: int):
2651*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2652*da0073e9SAndroid Build Coastguard Worker                self.net_1 = torch.nn.Sequential(
2653*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
2654*da0073e9SAndroid Build Coastguard Worker                ).cuda()
2655*da0073e9SAndroid Build Coastguard Worker                self.net_2 = torch.nn.Sequential(
2656*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
2657*da0073e9SAndroid Build Coastguard Worker                ).cuda()
2658*da0073e9SAndroid Build Coastguard Worker
2659*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2660*da0073e9SAndroid Build Coastguard Worker                return self.net_2(self.net_1(x))
2661*da0073e9SAndroid Build Coastguard Worker
2662*da0073e9SAndroid Build Coastguard Worker        class ParameterlessModule(torch.nn.Module):
2663*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2664*da0073e9SAndroid Build Coastguard Worker                idx = (
2665*da0073e9SAndroid Build Coastguard Worker                    torch.arange(x.size(0), device=x.device)
2666*da0073e9SAndroid Build Coastguard Worker                    .view(-1, 1)
2667*da0073e9SAndroid Build Coastguard Worker                    .repeat(1, x.size(1))
2668*da0073e9SAndroid Build Coastguard Worker                )
2669*da0073e9SAndroid Build Coastguard Worker                return {"output": torch.gather(x, 0, idx)}
2670*da0073e9SAndroid Build Coastguard Worker
2671*da0073e9SAndroid Build Coastguard Worker        models = []
2672*da0073e9SAndroid Build Coastguard Worker        for _ in range(2):
2673*da0073e9SAndroid Build Coastguard Worker            model_section1 = MLP1(D_in, H, H).cuda()
2674*da0073e9SAndroid Build Coastguard Worker            model_section2 = MLP2(H, H, D_out).cuda()
2675*da0073e9SAndroid Build Coastguard Worker            model_section3 = ParameterlessModule().cuda()
2676*da0073e9SAndroid Build Coastguard Worker            models.append(
2677*da0073e9SAndroid Build Coastguard Worker                torch.nn.Sequential(model_section1, model_section2, model_section3)
2678*da0073e9SAndroid Build Coastguard Worker            )
2679*da0073e9SAndroid Build Coastguard Worker
2680*da0073e9SAndroid Build Coastguard Worker        model_graphed = models[0]
2681*da0073e9SAndroid Build Coastguard Worker        model_control = models[1]
2682*da0073e9SAndroid Build Coastguard Worker
2683*da0073e9SAndroid Build Coastguard Worker        model_graphed.load_state_dict(model_control.state_dict())
2684*da0073e9SAndroid Build Coastguard Worker
2685*da0073e9SAndroid Build Coastguard Worker        opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1)
2686*da0073e9SAndroid Build Coastguard Worker        opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1)
2687*da0073e9SAndroid Build Coastguard Worker
2688*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(N, D_in, device="cuda")
2689*da0073e9SAndroid Build Coastguard Worker        h = torch.randn(N, H, device="cuda", requires_grad=True)
2690*da0073e9SAndroid Build Coastguard Worker        h2 = torch.randn(N, D_out, device="cuda", requires_grad=True)
2691*da0073e9SAndroid Build Coastguard Worker        unused_input = torch.randn(N, H, device="cuda", requires_grad=True)
2692*da0073e9SAndroid Build Coastguard Worker        y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True)
2693*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(N, D_out, device="cuda")
2694*da0073e9SAndroid Build Coastguard Worker
2695*da0073e9SAndroid Build Coastguard Worker        loss_fn_control = torch.nn.functional.mse_loss
2696*da0073e9SAndroid Build Coastguard Worker        relu_control = torch.nn.functional.relu
2697*da0073e9SAndroid Build Coastguard Worker
2698*da0073e9SAndroid Build Coastguard Worker        # This is a good stress test. It graphs five callables: three Modules and two Python functions.
2699*da0073e9SAndroid Build Coastguard Worker        with torch.amp.autocast(
2700*da0073e9SAndroid Build Coastguard Worker            device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
2701*da0073e9SAndroid Build Coastguard Worker        ):
2702*da0073e9SAndroid Build Coastguard Worker            (
2703*da0073e9SAndroid Build Coastguard Worker                model_graphed[0],
2704*da0073e9SAndroid Build Coastguard Worker                model_graphed[1],
2705*da0073e9SAndroid Build Coastguard Worker                model_graphed[2],
2706*da0073e9SAndroid Build Coastguard Worker                relu_graphed,
2707*da0073e9SAndroid Build Coastguard Worker                loss_fn_graphed,
2708*da0073e9SAndroid Build Coastguard Worker            ) = torch.cuda.make_graphed_callables(
2709*da0073e9SAndroid Build Coastguard Worker                (
2710*da0073e9SAndroid Build Coastguard Worker                    model_graphed[0],
2711*da0073e9SAndroid Build Coastguard Worker                    model_graphed[1],
2712*da0073e9SAndroid Build Coastguard Worker                    model_graphed[2],
2713*da0073e9SAndroid Build Coastguard Worker                    relu_control,
2714*da0073e9SAndroid Build Coastguard Worker                    loss_fn_control,
2715*da0073e9SAndroid Build Coastguard Worker                ),
2716*da0073e9SAndroid Build Coastguard Worker                (
2717*da0073e9SAndroid Build Coastguard Worker                    ({"x": x, "unused_input": unused_input},),
2718*da0073e9SAndroid Build Coastguard Worker                    (h,),
2719*da0073e9SAndroid Build Coastguard Worker                    (h2,),
2720*da0073e9SAndroid Build Coastguard Worker                    (y_pred,),
2721*da0073e9SAndroid Build Coastguard Worker                    (y_pred, y),
2722*da0073e9SAndroid Build Coastguard Worker                ),
2723*da0073e9SAndroid Build Coastguard Worker                allow_unused_input=allow_unused_input,
2724*da0073e9SAndroid Build Coastguard Worker            )
2725*da0073e9SAndroid Build Coastguard Worker
2726*da0073e9SAndroid Build Coastguard Worker        real_inputs = [torch.rand_like(x) for _ in range(10)]
2727*da0073e9SAndroid Build Coastguard Worker        real_targets = [torch.rand_like(y) for _ in range(10)]
2728*da0073e9SAndroid Build Coastguard Worker
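        # Run the same training loop with the graphed callables and the ungraphed
        # (control) callables; their parameters are compared afterwards and should
        # match exactly.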
2729*da0073e9SAndroid Build Coastguard Worker        for m, opt, relu, loss_fn in zip(
2730*da0073e9SAndroid Build Coastguard Worker            (model_graphed, model_control),
2731*da0073e9SAndroid Build Coastguard Worker            (opt_graphed, opt_control),
2732*da0073e9SAndroid Build Coastguard Worker            (relu_graphed, relu_control),
2733*da0073e9SAndroid Build Coastguard Worker            (loss_fn_graphed, loss_fn_control),
2734*da0073e9SAndroid Build Coastguard Worker        ):
2735*da0073e9SAndroid Build Coastguard Worker            # Reset RNG states before the iterations for the graphed and ungraphed models,
2736*da0073e9SAndroid Build Coastguard Worker            # so dropout math should be bitwise identical for both.
2737*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(5)
2738*da0073e9SAndroid Build Coastguard Worker            torch.cuda.manual_seed(5)
2739*da0073e9SAndroid Build Coastguard Worker            for data, target in zip(real_inputs, real_targets):
2740*da0073e9SAndroid Build Coastguard Worker                opt.zero_grad(set_to_none=True)
2741*da0073e9SAndroid Build Coastguard Worker                with torch.amp.autocast(
2742*da0073e9SAndroid Build Coastguard Worker                    device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
2743*da0073e9SAndroid Build Coastguard Worker                ):
2744*da0073e9SAndroid Build Coastguard Worker                    y_pred = m({"x": data, "unused_input": unused_input})["output"]
2745*da0073e9SAndroid Build Coastguard Worker                    y_pred = relu(y_pred)
2746*da0073e9SAndroid Build Coastguard Worker                    loss = loss_fn(y_pred, target)
2747*da0073e9SAndroid Build Coastguard Worker                    loss.backward()
2748*da0073e9SAndroid Build Coastguard Worker                opt.step()
2749*da0073e9SAndroid Build Coastguard Worker
2750*da0073e9SAndroid Build Coastguard Worker        for p, pc in zip(model_graphed.parameters(), model_control.parameters()):
2751*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p, pc)
2752*da0073e9SAndroid Build Coastguard Worker
2753*da0073e9SAndroid Build Coastguard Worker        # We graphed the models in training mode. Eval should still run ungraphed.
2754*da0073e9SAndroid Build Coastguard Worker        model_graphed.eval()
2755*da0073e9SAndroid Build Coastguard Worker        model_control.eval()
2756*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2757*da0073e9SAndroid Build Coastguard Worker            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
2758*da0073e9SAndroid Build Coastguard Worker        )
2759*da0073e9SAndroid Build Coastguard Worker
2760*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2761*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2762*da0073e9SAndroid Build Coastguard Worker    )
2763*da0073e9SAndroid Build Coastguard Worker    @parametrize(
2764*da0073e9SAndroid Build Coastguard Worker        "with_amp,cache_enabled,allow_unused_input",
2765*da0073e9SAndroid Build Coastguard Worker        [
2766*da0073e9SAndroid Build Coastguard Worker            subtest((False, False, True), decorators=[skipIfRocm]),
2767*da0073e9SAndroid Build Coastguard Worker            subtest((True, False, True), decorators=[skipIfRocm]),
2768*da0073e9SAndroid Build Coastguard Worker            subtest((True, True, True), decorators=[unittest.expectedFailure]),
2769*da0073e9SAndroid Build Coastguard Worker            subtest((False, False, False), decorators=[skipIfRocm]),
2770*da0073e9SAndroid Build Coastguard Worker        ],
2771*da0073e9SAndroid Build Coastguard Worker        name_fn=lambda x, y, z: "{}{}{}".format(
2772*da0073e9SAndroid Build Coastguard Worker            {True: "with_amp", False: "without_amp"}[x],
2773*da0073e9SAndroid Build Coastguard Worker            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
2774*da0073e9SAndroid Build Coastguard Worker            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
2775*da0073e9SAndroid Build Coastguard Worker        ),
2776*da0073e9SAndroid Build Coastguard Worker    )
2777*da0073e9SAndroid Build Coastguard Worker    @serialTest()
2778*da0073e9SAndroid Build Coastguard Worker    def test_graph_make_graphed_callables_parameterless_nograd_module(
2779*da0073e9SAndroid Build Coastguard Worker        self, with_amp, cache_enabled, allow_unused_input
2780*da0073e9SAndroid Build Coastguard Worker    ):
2781*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(5)
2782*da0073e9SAndroid Build Coastguard Worker        torch.cuda.manual_seed(5)
2783*da0073e9SAndroid Build Coastguard Worker
2784*da0073e9SAndroid Build Coastguard Worker        N, D_in, H, D_out = 640, 4096, 2048, 1024
2785*da0073e9SAndroid Build Coastguard Worker
2786*da0073e9SAndroid Build Coastguard Worker        class ParameterlessModule(torch.nn.Module):
2787*da0073e9SAndroid Build Coastguard Worker            def forward(self, input_dict: dict):
2788*da0073e9SAndroid Build Coastguard Worker                x = input_dict["x"]
2789*da0073e9SAndroid Build Coastguard Worker                idx = (
2790*da0073e9SAndroid Build Coastguard Worker                    torch.arange(x.size(0), device=x.device)
2791*da0073e9SAndroid Build Coastguard Worker                    .view(-1, 1)
2792*da0073e9SAndroid Build Coastguard Worker                    .repeat(1, x.size(1))
2793*da0073e9SAndroid Build Coastguard Worker                )
2794*da0073e9SAndroid Build Coastguard Worker                return {"output": torch.gather(x, 0, idx)}
2795*da0073e9SAndroid Build Coastguard Worker
2796*da0073e9SAndroid Build Coastguard Worker        models = []
2797*da0073e9SAndroid Build Coastguard Worker        for _ in range(2):
2798*da0073e9SAndroid Build Coastguard Worker            model_section1 = ParameterlessModule().cuda()
2799*da0073e9SAndroid Build Coastguard Worker            models.append(torch.nn.Sequential(model_section1))
2800*da0073e9SAndroid Build Coastguard Worker
2801*da0073e9SAndroid Build Coastguard Worker        model_graphed = models[0]
2802*da0073e9SAndroid Build Coastguard Worker        model_control = models[1]
2803*da0073e9SAndroid Build Coastguard Worker
2804*da0073e9SAndroid Build Coastguard Worker        model_graphed.load_state_dict(model_control.state_dict())
2805*da0073e9SAndroid Build Coastguard Worker
2806*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(N, D_in, device="cuda", requires_grad=False)
2807*da0073e9SAndroid Build Coastguard Worker        unused_input = torch.randn(N, H, device="cuda", requires_grad=False)
2808*da0073e9SAndroid Build Coastguard Worker        y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False)
2809*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(N, D_in, device="cuda")
2810*da0073e9SAndroid Build Coastguard Worker
2811*da0073e9SAndroid Build Coastguard Worker        # Graph a single parameterless Module whose inputs do not require grad.
2812*da0073e9SAndroid Build Coastguard Worker        with torch.amp.autocast(
2813*da0073e9SAndroid Build Coastguard Worker            device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
2814*da0073e9SAndroid Build Coastguard Worker        ):
2815*da0073e9SAndroid Build Coastguard Worker            model_graphed[0] = torch.cuda.make_graphed_callables(
2816*da0073e9SAndroid Build Coastguard Worker                model_graphed[0],
2817*da0073e9SAndroid Build Coastguard Worker                ({"x": x, "unused_input": unused_input},),
2818*da0073e9SAndroid Build Coastguard Worker                allow_unused_input=allow_unused_input,
2819*da0073e9SAndroid Build Coastguard Worker            )
2820*da0073e9SAndroid Build Coastguard Worker
2821*da0073e9SAndroid Build Coastguard Worker        real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)]
2822*da0073e9SAndroid Build Coastguard Worker        real_targets = [torch.rand_like(y) for _ in range(10)]
2823*da0073e9SAndroid Build Coastguard Worker
2824*da0073e9SAndroid Build Coastguard Worker        for m in (model_graphed, model_control):
2825*da0073e9SAndroid Build Coastguard Worker            # Reset RNG states before the iterations for the graphed and ungraphed models,
2826*da0073e9SAndroid Build Coastguard Worker            # so any randomized math should be bitwise identical for both.
2827*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(5)
2828*da0073e9SAndroid Build Coastguard Worker            torch.cuda.manual_seed(5)
2829*da0073e9SAndroid Build Coastguard Worker            for data, _ in zip(real_inputs, real_targets):
2830*da0073e9SAndroid Build Coastguard Worker                with torch.amp.autocast(
2831*da0073e9SAndroid Build Coastguard Worker                    device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
2832*da0073e9SAndroid Build Coastguard Worker                ):
2833*da0073e9SAndroid Build Coastguard Worker                    out = m({"x": data, "unused_input": unused_input})["output"]
2834*da0073e9SAndroid Build Coastguard Worker
2835*da0073e9SAndroid Build Coastguard Worker        # We graphed the models in training mode. Eval should still run ungraphed.
2836*da0073e9SAndroid Build Coastguard Worker        model_graphed.eval()
2837*da0073e9SAndroid Build Coastguard Worker        model_control.eval()
2838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2839*da0073e9SAndroid Build Coastguard Worker            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
2840*da0073e9SAndroid Build Coastguard Worker        )
2841*da0073e9SAndroid Build Coastguard Worker
2842*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2843*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2844*da0073e9SAndroid Build Coastguard Worker    )
2845*da0073e9SAndroid Build Coastguard Worker    def test_graph_make_graphed_callables_same_pool(self):
2846*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(5)
2847*da0073e9SAndroid Build Coastguard Worker        torch.cuda.manual_seed(5)
2848*da0073e9SAndroid Build Coastguard Worker        models = []
2849*da0073e9SAndroid Build Coastguard Worker        num_models = 3
2850*da0073e9SAndroid Build Coastguard Worker        for _ in range(num_models):
2851*da0073e9SAndroid Build Coastguard Worker            models.append(
2852*da0073e9SAndroid Build Coastguard Worker                torch.nn.Sequential(
2853*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(32, 128),
2854*da0073e9SAndroid Build Coastguard Worker                    torch.nn.ReLU(),
2855*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(128, 128),
2856*da0073e9SAndroid Build Coastguard Worker                ).cuda()
2857*da0073e9SAndroid Build Coastguard Worker            )
2858*da0073e9SAndroid Build Coastguard Worker        # we will reuse the same pool for all graph captures
2859*da0073e9SAndroid Build Coastguard Worker        mempool = torch.cuda.graph_pool_handle()
2860*da0073e9SAndroid Build Coastguard Worker        graphed_models = []
2861*da0073e9SAndroid Build Coastguard Worker        for model in models:
2862*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([64, 32], device="cuda")
2863*da0073e9SAndroid Build Coastguard Worker            graphed_model = deepcopy(model)
2864*da0073e9SAndroid Build Coastguard Worker            graphed_model = torch.cuda.make_graphed_callables(
2865*da0073e9SAndroid Build Coastguard Worker                graphed_model, (x,), pool=mempool
2866*da0073e9SAndroid Build Coastguard Worker            )
2867*da0073e9SAndroid Build Coastguard Worker            graphed_models.append(graphed_model)
2868*da0073e9SAndroid Build Coastguard Worker
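        # Each graphed copy should reproduce the original model exactly: same
        # outputs, same loss, same gradients, while still owning distinct
        # parameter and gradient storage.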
2869*da0073e9SAndroid Build Coastguard Worker        for model, graphed_model in zip(models, graphed_models):
2870*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([64, 32], device="cuda")
2871*da0073e9SAndroid Build Coastguard Worker            y = model(x)
2872*da0073e9SAndroid Build Coastguard Worker            yg = graphed_model(x)
2873*da0073e9SAndroid Build Coastguard Worker            l = y.norm()
2874*da0073e9SAndroid Build Coastguard Worker            lg = yg.norm()
2875*da0073e9SAndroid Build Coastguard Worker            l.backward()
2876*da0073e9SAndroid Build Coastguard Worker            lg.backward()
2877*da0073e9SAndroid Build Coastguard Worker
2878*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, yg)
2879*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(l, lg)
2880*da0073e9SAndroid Build Coastguard Worker            for p, pg in zip(model.parameters(), graphed_model.parameters()):
2881*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(p, pg)
2882*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(p.grad, pg.grad)
2883*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(p.data_ptr(), pg.data_ptr())
2884*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr())
2885*da0073e9SAndroid Build Coastguard Worker
2886*da0073e9SAndroid Build Coastguard Worker    def _test_graphed_optimizer(
2887*da0073e9SAndroid Build Coastguard Worker        self, steps_warmup, steps_train, optimizer_ctor, kwargs
2888*da0073e9SAndroid Build Coastguard Worker    ):
2889*da0073e9SAndroid Build Coastguard Worker        for actually_do_graphs in (True, False):
2890*da0073e9SAndroid Build Coastguard Worker            params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [
2891*da0073e9SAndroid Build Coastguard Worker                torch.randn((), device="cuda")
2892*da0073e9SAndroid Build Coastguard Worker            ]
2893*da0073e9SAndroid Build Coastguard Worker            params_control = [p.clone().requires_grad_() for p in params]
2894*da0073e9SAndroid Build Coastguard Worker            params_graphed = [p.clone().requires_grad_() for p in params]
2895*da0073e9SAndroid Build Coastguard Worker
2896*da0073e9SAndroid Build Coastguard Worker            grads = [
2897*da0073e9SAndroid Build Coastguard Worker                [torch.randn_like(p) for p in params]
2898*da0073e9SAndroid Build Coastguard Worker                for _ in range(steps_warmup + steps_train)
2899*da0073e9SAndroid Build Coastguard Worker            ]
2900*da0073e9SAndroid Build Coastguard Worker
2901*da0073e9SAndroid Build Coastguard Worker            # Control (capturable=False)
2902*da0073e9SAndroid Build Coastguard Worker
2903*da0073e9SAndroid Build Coastguard Worker            opt = optimizer_ctor(params_control, capturable=False, **kwargs)
2904*da0073e9SAndroid Build Coastguard Worker
2905*da0073e9SAndroid Build Coastguard Worker            for i in range(steps_warmup + steps_train):
2906*da0073e9SAndroid Build Coastguard Worker                for j, p in enumerate(params_control):
2907*da0073e9SAndroid Build Coastguard Worker                    p.grad = grads[i][j]
2908*da0073e9SAndroid Build Coastguard Worker                opt.step()
2909*da0073e9SAndroid Build Coastguard Worker
2910*da0073e9SAndroid Build Coastguard Worker            # capturable=True
2911*da0073e9SAndroid Build Coastguard Worker
2912*da0073e9SAndroid Build Coastguard Worker            opt = optimizer_ctor(params_graphed, capturable=True, **kwargs)
2913*da0073e9SAndroid Build Coastguard Worker
2914*da0073e9SAndroid Build Coastguard Worker            for i in range(steps_warmup):
2915*da0073e9SAndroid Build Coastguard Worker                for j, p in enumerate(params_graphed):
2916*da0073e9SAndroid Build Coastguard Worker                    p.grad = grads[i][j]
2917*da0073e9SAndroid Build Coastguard Worker                opt.step()
2918*da0073e9SAndroid Build Coastguard Worker
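            # Capture a single optimizer.step() into a CUDA graph; the replays
            # below feed new gradients by copying into the captured grad storage.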
2919*da0073e9SAndroid Build Coastguard Worker            if actually_do_graphs:
2920*da0073e9SAndroid Build Coastguard Worker                g = torch.cuda.CUDAGraph()
2921*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.graph(g):
2922*da0073e9SAndroid Build Coastguard Worker                    opt.step()
2923*da0073e9SAndroid Build Coastguard Worker
2924*da0073e9SAndroid Build Coastguard Worker            for i in range(steps_train):
2925*da0073e9SAndroid Build Coastguard Worker                if actually_do_graphs:
2926*da0073e9SAndroid Build Coastguard Worker                    for j, p in enumerate(params_graphed):
2927*da0073e9SAndroid Build Coastguard Worker                        p.grad.copy_(grads[i + steps_warmup][j])
2928*da0073e9SAndroid Build Coastguard Worker                    g.replay()
2929*da0073e9SAndroid Build Coastguard Worker                else:
2930*da0073e9SAndroid Build Coastguard Worker                    # Passing capturable=True to the constructor and running without graphs should still be
2931*da0073e9SAndroid Build Coastguard Worker                    # numerically correct, even if it's not ideal for performance.
2932*da0073e9SAndroid Build Coastguard Worker                    for j, p in enumerate(params_graphed):
2933*da0073e9SAndroid Build Coastguard Worker                        p.grad = grads[i + steps_warmup][j]
2934*da0073e9SAndroid Build Coastguard Worker                    opt.step()
2935*da0073e9SAndroid Build Coastguard Worker
2936*da0073e9SAndroid Build Coastguard Worker            for p_control, p_graphed in zip(params_control, params_graphed):
2937*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(p_control, p_graphed)
2938*da0073e9SAndroid Build Coastguard Worker
2939*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2940*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2941*da0073e9SAndroid Build Coastguard Worker    )
2942*da0073e9SAndroid Build Coastguard Worker    def test_graph_optims_with_explicitly_capturable_param_groups(self):
2943*da0073e9SAndroid Build Coastguard Worker        # Loosely mimics `_test_graphed_optimizer`, but passes two param_groups to optimizer.__init__.
2944*da0073e9SAndroid Build Coastguard Worker        n_warmup, n_replay = 3, 2
2945*da0073e9SAndroid Build Coastguard Worker        for optimizer, second_param_group_capturable in product(
2946*da0073e9SAndroid Build Coastguard Worker            (
2947*da0073e9SAndroid Build Coastguard Worker                torch.optim.Adam,
2948*da0073e9SAndroid Build Coastguard Worker                torch.optim.AdamW,
2949*da0073e9SAndroid Build Coastguard Worker                torch.optim.ASGD,
2950*da0073e9SAndroid Build Coastguard Worker                torch.optim.Adamax,
2951*da0073e9SAndroid Build Coastguard Worker                torch.optim.NAdam,
2952*da0073e9SAndroid Build Coastguard Worker                torch.optim.RAdam,
2953*da0073e9SAndroid Build Coastguard Worker                torch.optim.Adadelta,
2954*da0073e9SAndroid Build Coastguard Worker                torch.optim.RMSprop,
2955*da0073e9SAndroid Build Coastguard Worker                torch.optim.Rprop,
2956*da0073e9SAndroid Build Coastguard Worker            ),
2957*da0073e9SAndroid Build Coastguard Worker            (True, False),
2958*da0073e9SAndroid Build Coastguard Worker        ):
2959*da0073e9SAndroid Build Coastguard Worker            ref_p1, param1 = (
2960*da0073e9SAndroid Build Coastguard Worker                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
2961*da0073e9SAndroid Build Coastguard Worker            )
2962*da0073e9SAndroid Build Coastguard Worker            ref_p2, param2 = (
2963*da0073e9SAndroid Build Coastguard Worker                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
2964*da0073e9SAndroid Build Coastguard Worker            )
2965*da0073e9SAndroid Build Coastguard Worker            grads1, grads2 = (
2966*da0073e9SAndroid Build Coastguard Worker                [torch.randn_like(param1) for _ in range(n_warmup + n_replay)]
2967*da0073e9SAndroid Build Coastguard Worker                for _ in range(2)
2968*da0073e9SAndroid Build Coastguard Worker            )
2969*da0073e9SAndroid Build Coastguard Worker            ref_grads1, ref_grads2 = (
2970*da0073e9SAndroid Build Coastguard Worker                [t.clone() for t in tensors] for tensors in (grads1, grads2)
2971*da0073e9SAndroid Build Coastguard Worker            )
2972*da0073e9SAndroid Build Coastguard Worker            params = [
2973*da0073e9SAndroid Build Coastguard Worker                {"params": [param1], "capturable": True},
2974*da0073e9SAndroid Build Coastguard Worker                {"params": [param2], "capturable": second_param_group_capturable},
2975*da0073e9SAndroid Build Coastguard Worker            ]
2976*da0073e9SAndroid Build Coastguard Worker            opt = optimizer(params)
2977*da0073e9SAndroid Build Coastguard Worker            opt_ = optimizer(
2978*da0073e9SAndroid Build Coastguard Worker                [
2979*da0073e9SAndroid Build Coastguard Worker                    {"params": [ref_p1], "capturable": False},
2980*da0073e9SAndroid Build Coastguard Worker                    {"params": [ref_p2], "capturable": False},
2981*da0073e9SAndroid Build Coastguard Worker                ]
2982*da0073e9SAndroid Build Coastguard Worker            )
2983*da0073e9SAndroid Build Coastguard Worker
2984*da0073e9SAndroid Build Coastguard Worker            for i in range(n_warmup + n_replay):
2985*da0073e9SAndroid Build Coastguard Worker                ref_p1.grad = ref_grads1[i]
2986*da0073e9SAndroid Build Coastguard Worker                ref_p2.grad = ref_grads2[i]
2987*da0073e9SAndroid Build Coastguard Worker                opt_.step()
2988*da0073e9SAndroid Build Coastguard Worker
2989*da0073e9SAndroid Build Coastguard Worker            for i in range(n_warmup):
2990*da0073e9SAndroid Build Coastguard Worker                param1.grad = grads1[i]
2991*da0073e9SAndroid Build Coastguard Worker                param2.grad = grads2[i]
2992*da0073e9SAndroid Build Coastguard Worker                opt.step()
2993*da0073e9SAndroid Build Coastguard Worker
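            # Capturing step() should fail while any param_group still has
            # capturable=False; otherwise capture once, replay, and compare
            # against the eager (non-capturable) reference parameters.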
2994*da0073e9SAndroid Build Coastguard Worker            g = torch.cuda.CUDAGraph()
2995*da0073e9SAndroid Build Coastguard Worker            if not second_param_group_capturable:
2996*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"):
2997*da0073e9SAndroid Build Coastguard Worker                    with torch.cuda.graph(g):
2998*da0073e9SAndroid Build Coastguard Worker                        opt.step()
2999*da0073e9SAndroid Build Coastguard Worker            else:
3000*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.graph(g):
3001*da0073e9SAndroid Build Coastguard Worker                    opt.step()
3002*da0073e9SAndroid Build Coastguard Worker
3003*da0073e9SAndroid Build Coastguard Worker                for i in range(n_replay):
3004*da0073e9SAndroid Build Coastguard Worker                    param1.grad.copy_(grads1[n_warmup + i])
3005*da0073e9SAndroid Build Coastguard Worker                    param2.grad.copy_(grads2[n_warmup + i])
3006*da0073e9SAndroid Build Coastguard Worker                    g.replay()
3007*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref_p1, param1)
3008*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref_p2, param2)
3009*da0073e9SAndroid Build Coastguard Worker
3010*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3011*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
3012*da0073e9SAndroid Build Coastguard Worker    )
3013*da0073e9SAndroid Build Coastguard Worker    def test_cuda_graph_error_options(self):
3014*da0073e9SAndroid Build Coastguard Worker        def fn():
3015*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros([2000], device="cuda")
3016*da0073e9SAndroid Build Coastguard Worker            y = x + x + x
3017*da0073e9SAndroid Build Coastguard Worker            return y
3018*da0073e9SAndroid Build Coastguard Worker
3019*da0073e9SAndroid Build Coastguard Worker        mem = None
3020*da0073e9SAndroid Build Coastguard Worker
3021*da0073e9SAndroid Build Coastguard Worker        def raw_malloc():
3022*da0073e9SAndroid Build Coastguard Worker            nonlocal mem  # write to the enclosing test's `mem` so the cleanup paths can see the allocation
3023*da0073e9SAndroid Build Coastguard Worker            mem = None
3024*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.Stream()
3025*da0073e9SAndroid Build Coastguard Worker            try:
3026*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.stream(stream):
3027*da0073e9SAndroid Build Coastguard Worker                    mem = torch.cuda.caching_allocator_alloc(1024)
3028*da0073e9SAndroid Build Coastguard Worker            except BaseException:
3029*da0073e9SAndroid Build Coastguard Worker                if mem is None:
3030*da0073e9SAndroid Build Coastguard Worker                    return
3031*da0073e9SAndroid Build Coastguard Worker            try:
3032*da0073e9SAndroid Build Coastguard Worker                torch.cuda.caching_allocator_delete(mem)
3033*da0073e9SAndroid Build Coastguard Worker                mem = None
3034*da0073e9SAndroid Build Coastguard Worker                return None
3035*da0073e9SAndroid Build Coastguard Worker            except BaseException:
3036*da0073e9SAndroid Build Coastguard Worker                pass
3037*da0073e9SAndroid Build Coastguard Worker
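        # Attempt a raw caching-allocator allocation from a side thread in the
        # middle of a graph capture.  With "thread_local" or "relaxed"
        # capture_error_mode this should not break the capture, so the helper
        # returns False.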
3038*da0073e9SAndroid Build Coastguard Worker        def throws_on_cuda_event(capture_error_mode):
3039*da0073e9SAndroid Build Coastguard Worker            graph = torch.cuda.CUDAGraph()
3040*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
3041*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.Stream()
3042*da0073e9SAndroid Build Coastguard Worker            stream.wait_stream(torch.cuda.current_stream())
3043*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(stream):
3044*da0073e9SAndroid Build Coastguard Worker                fn()
3045*da0073e9SAndroid Build Coastguard Worker            stream.synchronize()
3046*da0073e9SAndroid Build Coastguard Worker            torch.cuda.current_stream().wait_stream(stream)
3047*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
3048*da0073e9SAndroid Build Coastguard Worker            try:
3049*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.graph(
3050*da0073e9SAndroid Build Coastguard Worker                    graph, stream=stream, capture_error_mode=capture_error_mode
3051*da0073e9SAndroid Build Coastguard Worker                ):
3052*da0073e9SAndroid Build Coastguard Worker                    out = fn()
3053*da0073e9SAndroid Build Coastguard Worker                    thread = threading.Thread(target=raw_malloc)
3054*da0073e9SAndroid Build Coastguard Worker                    thread.start()
3055*da0073e9SAndroid Build Coastguard Worker                    thread.join()
3056*da0073e9SAndroid Build Coastguard Worker            except Exception:
3057*da0073e9SAndroid Build Coastguard Worker                if mem is not None:
3058*da0073e9SAndroid Build Coastguard Worker                    torch.cuda.caching_allocator_delete(mem)
3059*da0073e9SAndroid Build Coastguard Worker                return True
3060*da0073e9SAndroid Build Coastguard Worker
3061*da0073e9SAndroid Build Coastguard Worker            return False
3062*da0073e9SAndroid Build Coastguard Worker
3063*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(throws_on_cuda_event("thread_local"))
3064*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(throws_on_cuda_event("relaxed"))
3065*da0073e9SAndroid Build Coastguard Worker
3066*da0073e9SAndroid Build Coastguard Worker        # An exception here would corrupt the process and make other tests fail.
3067*da0073e9SAndroid Build Coastguard Worker        # self.assertTrue(throws_on_cuda_event("global"))
3068*da0073e9SAndroid Build Coastguard Worker
3069*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3070*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
3071*da0073e9SAndroid Build Coastguard Worker    )
3072*da0073e9SAndroid Build Coastguard Worker    def test_cuda_graph_allocator_propagates_stream(self):
3073*da0073e9SAndroid Build Coastguard Worker        segments = torch.cuda.memory_snapshot()
3074*da0073e9SAndroid Build Coastguard Worker        existing_pools = {s["segment_pool_id"] for s in segments}
3075*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10240000, device="cuda")
3076*da0073e9SAndroid Build Coastguard Worker        y = torch.rand_like(x)
3077*da0073e9SAndroid Build Coastguard Worker        g = torch.cuda.CUDAGraph()
3078*da0073e9SAndroid Build Coastguard Worker        s0 = torch.cuda.Stream()
3079*da0073e9SAndroid Build Coastguard Worker        s1 = torch.cuda.Stream()
3080*da0073e9SAndroid Build Coastguard Worker        s0.wait_stream(torch.cuda.current_stream())
3081*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s0):
3082*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
3083*da0073e9SAndroid Build Coastguard Worker            z = x + y
3084*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s1):
3085*da0073e9SAndroid Build Coastguard Worker            s1.wait_stream(s0)
3086*da0073e9SAndroid Build Coastguard Worker            w = z + y
3087*da0073e9SAndroid Build Coastguard Worker        s0.wait_stream(s1)
3088*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s0):
3089*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
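        # Both tensors allocated during capture (z on s0, w on s1) should come
        # from the same graph-private memory pool: exactly two new segments
        # sharing one segment_pool_id.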
3090*da0073e9SAndroid Build Coastguard Worker        segments = torch.cuda.memory_snapshot()
3091*da0073e9SAndroid Build Coastguard Worker        new_pools = [
3092*da0073e9SAndroid Build Coastguard Worker            s["segment_pool_id"]
3093*da0073e9SAndroid Build Coastguard Worker            for s in segments
3094*da0073e9SAndroid Build Coastguard Worker            if s["segment_pool_id"] not in existing_pools
3095*da0073e9SAndroid Build Coastguard Worker        ]
3096*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(new_pools), 2)
3097*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_pools[0], new_pools[1])
3098*da0073e9SAndroid Build Coastguard Worker
3099*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm_gather_stats(self):
3100*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 3, 3, 3, device="cuda")
3101*da0073e9SAndroid Build Coastguard Worker        mean, invstd = torch.batch_norm_gather_stats(
3102*da0073e9SAndroid Build Coastguard Worker            input,
3103*da0073e9SAndroid Build Coastguard Worker            mean=torch.ones(2, 3, device="cuda"),
3104*da0073e9SAndroid Build Coastguard Worker            invstd=torch.ones(2, 3, device="cuda"),
3105*da0073e9SAndroid Build Coastguard Worker            running_mean=None,
3106*da0073e9SAndroid Build Coastguard Worker            running_var=None,
3107*da0073e9SAndroid Build Coastguard Worker            momentum=0.1,
3108*da0073e9SAndroid Build Coastguard Worker            eps=1e-5,
3109*da0073e9SAndroid Build Coastguard Worker            count=2,
3110*da0073e9SAndroid Build Coastguard Worker        )
3111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mean, torch.ones(3, device="cuda"))
3112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(invstd, torch.ones(3, device="cuda"))
3113*da0073e9SAndroid Build Coastguard Worker
3114*da0073e9SAndroid Build Coastguard Worker    def test_matmul_memory_use(self):
3115*da0073e9SAndroid Build Coastguard Worker        def get_max_used():
3116*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
3117*da0073e9SAndroid Build Coastguard Worker            val = torch.cuda.max_memory_allocated()
3118*da0073e9SAndroid Build Coastguard Worker            torch.cuda.reset_peak_memory_stats()
3119*da0073e9SAndroid Build Coastguard Worker            return val
3120*da0073e9SAndroid Build Coastguard Worker
3121*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1, 32, 32, device="cuda")
3122*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(24, 32, 1, device="cuda")
3123*da0073e9SAndroid Build Coastguard Worker
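        # matmul against an expand()ed (broadcast) batch should not materialize
        # the expansion, so its peak memory should match the plain matmul, and
        # bmm on the expanded tensor should match as well.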
3124*da0073e9SAndroid Build Coastguard Worker        get_max_used()
3125*da0073e9SAndroid Build Coastguard Worker
3126*da0073e9SAndroid Build Coastguard Worker        torch.matmul(a, b)
3127*da0073e9SAndroid Build Coastguard Worker
3128*da0073e9SAndroid Build Coastguard Worker        matmul_mem = get_max_used()
3129*da0073e9SAndroid Build Coastguard Worker
3130*da0073e9SAndroid Build Coastguard Worker        a = a.expand(24, 32, 32)
3131*da0073e9SAndroid Build Coastguard Worker        torch.matmul(a, b)
3132*da0073e9SAndroid Build Coastguard Worker
3133*da0073e9SAndroid Build Coastguard Worker        matmul_expand_mem = get_max_used()
3134*da0073e9SAndroid Build Coastguard Worker
3135*da0073e9SAndroid Build Coastguard Worker        torch.bmm(a, b)
3136*da0073e9SAndroid Build Coastguard Worker
3137*da0073e9SAndroid Build Coastguard Worker        bmm_mem = get_max_used()
3138*da0073e9SAndroid Build Coastguard Worker
3139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(matmul_expand_mem, matmul_mem)
3140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bmm_mem, matmul_mem)
3141*da0073e9SAndroid Build Coastguard Worker
3142*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test")
3143*da0073e9SAndroid Build Coastguard Worker    def test_rocm_backward_pass_guard(self):
3144*da0073e9SAndroid Build Coastguard Worker        # The test exercises a ROCm-specific feature.
3145*da0073e9SAndroid Build Coastguard Worker
3146*da0073e9SAndroid Build Coastguard Worker        class MyFunction(torch.autograd.Function):
3147*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3148*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, tensor, constant):
3149*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch._C._rocm_is_backward_pass())
3150*da0073e9SAndroid Build Coastguard Worker                ctx.constant = constant
3151*da0073e9SAndroid Build Coastguard Worker                return tensor * constant
3152*da0073e9SAndroid Build Coastguard Worker
3153*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3154*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
3155*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch._C._rocm_is_backward_pass())
3156*da0073e9SAndroid Build Coastguard Worker                return grad_output * ctx.constant, None
3157*da0073e9SAndroid Build Coastguard Worker
3158*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
3159*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3160*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3161*da0073e9SAndroid Build Coastguard Worker                self.a = torch.nn.Parameter(torch.randn(()))
3162*da0073e9SAndroid Build Coastguard Worker
3163*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3164*da0073e9SAndroid Build Coastguard Worker                return MyFunction.apply(x, self.a)
3165*da0073e9SAndroid Build Coastguard Worker
3166*da0073e9SAndroid Build Coastguard Worker        model = MyModule()
3167*da0073e9SAndroid Build Coastguard Worker        criterion = torch.nn.MSELoss(reduction="sum")
3168*da0073e9SAndroid Build Coastguard Worker        optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
3169*da0073e9SAndroid Build Coastguard Worker
3170*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
3171*da0073e9SAndroid Build Coastguard Worker        result = model(x)
3172*da0073e9SAndroid Build Coastguard Worker        loss = criterion(result, x)
3173*da0073e9SAndroid Build Coastguard Worker        optimizer.zero_grad()
3174*da0073e9SAndroid Build Coastguard Worker        loss.backward()
3175*da0073e9SAndroid Build Coastguard Worker        optimizer.step()
3176*da0073e9SAndroid Build Coastguard Worker
3177*da0073e9SAndroid Build Coastguard Worker    def test_matmul_device_mismatch(self):
3178*da0073e9SAndroid Build Coastguard Worker        cpu = torch.rand((10, 10))
3179*da0073e9SAndroid Build Coastguard Worker        cuda = cpu.cuda()
3180*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3181*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device"
3182*da0073e9SAndroid Build Coastguard Worker        ):
3183*da0073e9SAndroid Build Coastguard Worker            cpu @ cuda
3184*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3185*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device"
3186*da0073e9SAndroid Build Coastguard Worker        ):
3187*da0073e9SAndroid Build Coastguard Worker            cuda @ cpu
3188*da0073e9SAndroid Build Coastguard Worker
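        # torch.addmm should likewise reject any mix of CPU and CUDA tensors;
        # only the all-CPU and all-CUDA combinations are expected to succeed.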
3189*da0073e9SAndroid Build Coastguard Worker        for s, m1, m2 in product((cpu, cuda), repeat=3):
3190*da0073e9SAndroid Build Coastguard Worker            if s.device == m1.device == m2.device:
3191*da0073e9SAndroid Build Coastguard Worker                torch.addmm(s, m1, m2)
3192*da0073e9SAndroid Build Coastguard Worker            else:
3193*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
3194*da0073e9SAndroid Build Coastguard Worker                    RuntimeError, "Expected all tensors to be on the same device"
3195*da0073e9SAndroid Build Coastguard Worker                ):
3196*da0073e9SAndroid Build Coastguard Worker                    torch.addmm(s, m1, m2)
3197*da0073e9SAndroid Build Coastguard Worker
3198*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient")
3199*da0073e9SAndroid Build Coastguard Worker    def test_lazy_init(self):
3200*da0073e9SAndroid Build Coastguard Worker        """Validate that no CUDA calls are made during the `import torch` call"""
3201*da0073e9SAndroid Build Coastguard Worker
3202*da0073e9SAndroid Build Coastguard Worker        def check_output(script: str) -> str:
3203*da0073e9SAndroid Build Coastguard Worker            return (
3204*da0073e9SAndroid Build Coastguard Worker                subprocess.check_output([sys.executable, "-c", script])
3205*da0073e9SAndroid Build Coastguard Worker                .decode("ascii")
3206*da0073e9SAndroid Build Coastguard Worker                .strip()
3207*da0073e9SAndroid Build Coastguard Worker            )
3208*da0073e9SAndroid Build Coastguard Worker
3209*da0073e9SAndroid Build Coastguard Worker        VISIBLE_DEVICES = (
3210*da0073e9SAndroid Build Coastguard Worker            "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES"
3211*da0073e9SAndroid Build Coastguard Worker        )
3212*da0073e9SAndroid Build Coastguard Worker        test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())"
3213*da0073e9SAndroid Build Coastguard Worker        rc = check_output(test_script)
3214*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rc, "0")
3215*da0073e9SAndroid Build Coastguard Worker        if not TEST_WITH_ROCM:
3216*da0073e9SAndroid Build Coastguard Worker            # Check that `cuInit` was not called during the import
3217*da0073e9SAndroid Build Coastguard Worker            # by using ctypes to call cuDeviceGetCount() and expecting CUDA_ERROR_NOT_INITIALIZED == 3.
3218*da0073e9SAndroid Build Coastguard Worker            # See https://github.com/pytorch/pytorch/issues/116276 for more details
3219*da0073e9SAndroid Build Coastguard Worker            libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll"
3220*da0073e9SAndroid Build Coastguard Worker            cuda_driver_api_call = (
3221*da0073e9SAndroid Build Coastguard Worker                f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))"
3222*da0073e9SAndroid Build Coastguard Worker            )
3223*da0073e9SAndroid Build Coastguard Worker            rc = check_output(
3224*da0073e9SAndroid Build Coastguard Worker                f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})"
3225*da0073e9SAndroid Build Coastguard Worker            )
3226*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rc, "3")
3227*da0073e9SAndroid Build Coastguard Worker
3228*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
3229*da0073e9SAndroid Build Coastguard Worker    def test_hip_device_count(self):
3230*da0073e9SAndroid Build Coastguard Worker        """Validate device_count works with both CUDA/HIP visible devices"""
3231*da0073e9SAndroid Build Coastguard Worker        test_script = """\
3232*da0073e9SAndroid Build Coastguard Workerimport torch
3233*da0073e9SAndroid Build Coastguard Workerimport os
3234*da0073e9SAndroid Build Coastguard Workerprint(f"{torch.cuda.device_count()}")
3235*da0073e9SAndroid Build Coastguard Worker"""
3236*da0073e9SAndroid Build Coastguard Worker        custom_envs = [
3237*da0073e9SAndroid Build Coastguard Worker            {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
3238*da0073e9SAndroid Build Coastguard Worker            {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"},
3239*da0073e9SAndroid Build Coastguard Worker            {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"},
3240*da0073e9SAndroid Build Coastguard Worker        ]
3241*da0073e9SAndroid Build Coastguard Worker
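        # Each configuration above should leave exactly one device visible,
        # whether restricted via CUDA_VISIBLE_DEVICES, HIP_VISIBLE_DEVICES, or
        # both (in which case the HIP setting is the one honored here).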
3242*da0073e9SAndroid Build Coastguard Worker        for env_config in custom_envs:
3243*da0073e9SAndroid Build Coastguard Worker            env = os.environ.copy()
3244*da0073e9SAndroid Build Coastguard Worker            for key, value in env_config.items():
3245*da0073e9SAndroid Build Coastguard Worker                if value is None:
3246*da0073e9SAndroid Build Coastguard Worker                    env.pop(key, None)
3247*da0073e9SAndroid Build Coastguard Worker                else:
3248*da0073e9SAndroid Build Coastguard Worker                    env[key] = value
3249*da0073e9SAndroid Build Coastguard Worker            r = (
3250*da0073e9SAndroid Build Coastguard Worker                subprocess.check_output([sys.executable, "-c", test_script], env=env)
3251*da0073e9SAndroid Build Coastguard Worker                .decode("ascii")
3252*da0073e9SAndroid Build Coastguard Worker                .strip()
3253*da0073e9SAndroid Build Coastguard Worker            )
3254*da0073e9SAndroid Build Coastguard Worker            self.assertEqual("1", r)
3255*da0073e9SAndroid Build Coastguard Worker
3256*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices")
3257*da0073e9SAndroid Build Coastguard Worker    def test_device_count_not_cached_pre_init(self):
3258*da0073e9SAndroid Build Coastguard Worker        visible_devices = (
3259*da0073e9SAndroid Build Coastguard Worker            "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES"
3260*da0073e9SAndroid Build Coastguard Worker        )
3261*da0073e9SAndroid Build Coastguard Worker        test_script = f"""\
3262*da0073e9SAndroid Build Coastguard Workerimport torch
3263*da0073e9SAndroid Build Coastguard Workerimport os
3264*da0073e9SAndroid Build Coastguard Workerr1 = torch.cuda.device_count()
3265*da0073e9SAndroid Build Coastguard Workeros.environ['{visible_devices}'] = '0'
3266*da0073e9SAndroid Build Coastguard Workerr2 = torch.cuda.device_count()
3267*da0073e9SAndroid Build Coastguard Workertorch.empty(10, device='cuda')
3268*da0073e9SAndroid Build Coastguard Workerprint(f"{{r1}}, {{r2}}")
3269*da0073e9SAndroid Build Coastguard Worker"""
3270*da0073e9SAndroid Build Coastguard Worker
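        # r1 is read before CUDA initialization and should see every device;
        # r2 is read after shrinking the visible-devices env var and should be 1,
        # showing the pre-init device count is not cached.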
3271*da0073e9SAndroid Build Coastguard Worker        r = (
3272*da0073e9SAndroid Build Coastguard Worker            subprocess.check_output([sys.executable, "-c", test_script])
3273*da0073e9SAndroid Build Coastguard Worker            .decode("ascii")
3274*da0073e9SAndroid Build Coastguard Worker            .strip()
3275*da0073e9SAndroid Build Coastguard Worker        )
3276*da0073e9SAndroid Build Coastguard Worker
3277*da0073e9SAndroid Build Coastguard Worker        x = torch.cuda.device_count()
3278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f"{x}, 1", r)
3279*da0073e9SAndroid Build Coastguard Worker
3280*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
3281*da0073e9SAndroid Build Coastguard Worker    def test_gds_fails_in_ci(self):
3282*da0073e9SAndroid Build Coastguard Worker        if IS_WINDOWS or TEST_WITH_ROCM:
3283*da0073e9SAndroid Build Coastguard Worker            error_msg = "is not supported on this platform"
3284*da0073e9SAndroid Build Coastguard Worker        else:
3285*da0073e9SAndroid Build Coastguard Worker            error_msg = "cuFileHandleRegister failed"
3286*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as f:
3287*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, error_msg):
3288*da0073e9SAndroid Build Coastguard Worker                file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
3289*da0073e9SAndroid Build Coastguard Worker
3290*da0073e9SAndroid Build Coastguard Worker
3291*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
3292*da0073e9SAndroid Build Coastguard Worker@torch.testing._internal.common_utils.markDynamoStrictTest
3293*da0073e9SAndroid Build Coastguard Workerclass TestCudaMallocAsync(TestCase):
3294*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3295*da0073e9SAndroid Build Coastguard Worker        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3296*da0073e9SAndroid Build Coastguard Worker    )
3297*da0073e9SAndroid Build Coastguard Worker    def test_memory_snapshot(self):
3298*da0073e9SAndroid Build Coastguard Worker        try:
3299*da0073e9SAndroid Build Coastguard Worker            torch.cuda.memory.empty_cache()
3300*da0073e9SAndroid Build Coastguard Worker            torch.cuda.memory._record_memory_history("state", stacks="python")
3301*da0073e9SAndroid Build Coastguard Worker            # make x the second block in a segment
3302*da0073e9SAndroid Build Coastguard Worker            torch.rand(2 * 311, 411, device="cuda")
3303*da0073e9SAndroid Build Coastguard Worker            unused = torch.rand(310, 410, device="cuda")
3304*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(311, 411, device="cuda")
3305*da0073e9SAndroid Build Coastguard Worker
3306*da0073e9SAndroid Build Coastguard Worker            # Create a bunch of tensors that will all tile into the
3307*da0073e9SAndroid Build Coastguard Worker            # same segment to exercise the history merging code.
3308*da0073e9SAndroid Build Coastguard Worker            # 512B is the minimum block size,
3309*da0073e9SAndroid Build Coastguard Worker            # so we allocate all the tensors at this size to make sure
3310*da0073e9SAndroid Build Coastguard Worker            # they tile evenly.
3311*da0073e9SAndroid Build Coastguard Worker            tensors = [torch.rand(128, device="cuda") for _ in range(1000)]
3312*da0073e9SAndroid Build Coastguard Worker            while tensors:
3313*da0073e9SAndroid Build Coastguard Worker                del tensors[randint(0, len(tensors) - 1)]
3314*da0073e9SAndroid Build Coastguard Worker
3315*da0073e9SAndroid Build Coastguard Worker            # exercise the history trimming code
3316*da0073e9SAndroid Build Coastguard Worker            torch.rand(128 * 5, device="cuda")
3317*da0073e9SAndroid Build Coastguard Worker
3318*da0073e9SAndroid Build Coastguard Worker            ss = torch.cuda.memory._snapshot()
3319*da0073e9SAndroid Build Coastguard Worker            found_it = False
3320*da0073e9SAndroid Build Coastguard Worker            for seg in ss["segments"]:
3321*da0073e9SAndroid Build Coastguard Worker                self.assertTrue("frames" in seg)
3322*da0073e9SAndroid Build Coastguard Worker                for b in seg["blocks"]:
3323*da0073e9SAndroid Build Coastguard Worker                    if b["requested_size"] == 311 * 411 * 4:
3324*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("test_cuda" in b["frames"][0]["filename"])
3325*da0073e9SAndroid Build Coastguard Worker                        found_it = True
3326*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(x.untyped_storage().data_ptr(), b["address"])
3327*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(found_it)
3328*da0073e9SAndroid Build Coastguard Worker
3329*da0073e9SAndroid Build Coastguard Worker            if not IS_WINDOWS:
3330*da0073e9SAndroid Build Coastguard Worker                with tempfile.NamedTemporaryFile() as f:
3331*da0073e9SAndroid Build Coastguard Worker                    torch.cuda.memory._save_segment_usage(f.name)
3332*da0073e9SAndroid Build Coastguard Worker                    with open(f.name) as f2:
3333*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("test_cuda.py" in f2.read())
3334*da0073e9SAndroid Build Coastguard Worker            del unused
3335*da0073e9SAndroid Build Coastguard Worker            del x
3336*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
3337*da0073e9SAndroid Build Coastguard Worker            ss = torch.cuda.memory._snapshot()
3338*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
3339*da0073e9SAndroid Build Coastguard Worker                ss["device_traces"][0][-1]["action"]
3340*da0073e9SAndroid Build Coastguard Worker                in ("segment_free", "segment_unmap")
3341*da0073e9SAndroid Build Coastguard Worker            )
3342*da0073e9SAndroid Build Coastguard Worker
3343*da0073e9SAndroid Build Coastguard Worker        finally:
3344*da0073e9SAndroid Build Coastguard Worker            torch.cuda.memory._record_memory_history(None)
3345*da0073e9SAndroid Build Coastguard Worker
3346*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
    def test_direct_traceback(self):
        from torch._C._profiler import gather_traceback, symbolize_tracebacks

        c = gather_traceback(True, True, True)
        (r,) = symbolize_tracebacks([c])
        r = str(r)
        self.assertTrue("test_cuda.py" in r)
        self.assertTrue("unwind" in r)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_snapshot_with_cpp(self):
        try:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history("state", stacks="all")
            x = torch.rand(311, 411, device="cuda")

            ss = torch.cuda.memory._snapshot()["segments"]
            found_it = False
            for seg in ss:
                for b in seg["blocks"]:
                    if b["requested_size"] == 311 * 411 * 4:
                        self.assertTrue("::rand" in str(b["frames"]))
                        found_it = True
            self.assertTrue(found_it)

        finally:
            torch.cuda.memory._record_memory_history(None)

    @skipIfRocm
    def test_memory_profiler_viz(self):
        with torch.profiler.profile(
            with_stack=True, profile_memory=True, record_shapes=True
        ) as prof:
            x = torch.rand(128, 128, device="cuda")
            x * x + x * x  # create temporaries for the profiler to record
        profile_plot(prof)  # exercise the HTML plot path; output not checked
        plot = json.dumps(_profile_to_snapshot(prof))
        self.assertTrue("test_cuda.py" in plot)
        self.assertTrue("test_memory_profiler_viz" in plot)
        self.assertTrue("category" in plot)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_cycles(self):
        fired = False

        def observer(html):
            nonlocal fired
            fired = True
            self.assertTrue("torch.Tensor" in html)
            self.assertTrue("test_cuda" in html)
            self.assertTrue("cell_contents" in html)

        disarm = observe_tensor_cycles(observer)

        def noop():
            pass

        try:

            def create():
                x = torch.empty(3, 4, device="cuda")

                def foo(p):
                    if p:
                        return foo(not p)
                    else:
                        return x

                return foo

            create()
            gc.collect()
            # the callback has to run outside of the collect() call, so it
            # doesn't actually fire until the next method call after a
            # gc.collect
            noop()
            self.assertTrue(fired)
        finally:
            disarm()

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots(self):
        for context, stacks in (
            ("all", "all" if IS_LINUX else "python"),
            ("all", "python"),
            (None, "python"),
        ):
            try:
                torch.cuda.memory.empty_cache()
                torch.cuda.memory._record_memory_history(
                    "all", context=context, stacks=stacks
                )

                def run():
                    x = torch.rand(128, 128, device="cuda")
                    x * x + x * x

                run()
                cpp = stacks == "all"
                record_context = context is not None
                ss = torch.cuda.memory._snapshot()

                tplot = trace_plot(ss)
                splot = segment_plot(ss)
                text = json.dumps(ss)

                self.assertTrue(record_context == ("test_memory_plots" in text))
                self.assertTrue(cpp == ("::rand" in text))
                self.assertTrue(str(128 * 128 * 4) in text)

            finally:
                torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_free_stack(self):
        for context in ["alloc", "all", "state"]:
            try:
                torch.cuda.memory.empty_cache()
                torch.cuda.memory._record_memory_history(context=context)
                x = None

                def thealloc():
                    nonlocal x
                    x = torch.rand(3, 4, device="cuda")

                def thefree():
                    nonlocal x
                    del x

                thealloc()
                thefree()
                ss = json.dumps(torch.cuda.memory._snapshot())
                self.assertTrue(("thefree" in ss) == (context == "all"))
                self.assertTrue(("thealloc" in ss) == (context != "state"))
            finally:
                torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_history_context(self):
        try:
            torch.cuda.memory.empty_cache()
            x = None

            def should_capture1():
                nonlocal x
                x = torch.rand(4, 4, device="cuda")

            def should_not_capture():
                nonlocal x
                x = torch.rand(3, 4, device="cuda")

            def should_capture2():
                nonlocal x
                x = torch.rand(4, 4, device="cuda")

            # Recording with context and python call stacks should capture the call stack.
            torch.cuda.memory._record_memory_history(context="all", stacks="python")
            should_capture1()
            # Recording with context=None should not capture the call stack.
            torch.cuda.memory._record_memory_history(context=None)
            should_not_capture()
            # Recording with context and python call stacks should capture the call stack.
            torch.cuda.memory._record_memory_history(context="all", stacks="python")
            should_capture2()

            ss = json.dumps(torch.cuda.memory._snapshot())
            self.assertTrue("should_capture1" in ss)
            self.assertTrue("should_not_capture" not in ss)
            self.assertTrue("should_capture2" in ss)
        finally:
            torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_free_segment_stack(self):
        for context in ["alloc", "all", "state"]:
            try:
                torch.cuda.memory.empty_cache()
                torch.cuda.memory._record_memory_history(context=context)
                x = torch.rand(3, 4, device="cuda")
                del x
                torch.cuda.memory.empty_cache()

                ss = json.dumps(torch.cuda.memory._snapshot())
                self.assertTrue(("empty_cache" in ss) == (context == "all"))
            finally:
                torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_memory_snapshot_script(self):
        try:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history("state", stacks="python")

            @torch.jit.script
            def foo():
                return torch.rand(311, 411, device="cuda")

            x = foo()

            ss = torch.cuda.memory._snapshot()["segments"]
            found_it = False
            for seg in ss:
                for b in seg["blocks"]:
                    if b["requested_size"] == 311 * 411 * 4:
                        self.assertTrue(b["frames"][0]["name"] == "foo")
                        found_it = True
            self.assertTrue(found_it)

        finally:
            torch.cuda.memory._record_memory_history(None)

    def test_max_split_expandable(self):
        torch.cuda.memory.empty_cache()
        mb = 1024 * 1024
        _, all_memory = torch.cuda.memory.mem_get_info()
        total_allowed = 120 * mb
        fraction_allowed = total_allowed / all_memory
        assert int(fraction_allowed * all_memory) == total_allowed
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)

        def alloc(n):
            return torch.ones(n * mb, dtype=torch.int8, device="cuda")

        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,max_split_size_mb:40"
        )
        a = alloc(40)
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:True,max_split_size_mb:40"
        )
        b = alloc(40)
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,max_split_size_mb:40"
        )
        c = alloc(40)
        with self.assertRaises(torch.OutOfMemoryError):
            alloc(40)
        del a, b, c
        # force release_cached_blocks to run with some expandable segments in the free list
        alloc(120)

    def test_garbage_collect_expandable(self):
        torch.cuda.memory.empty_cache()
        mb = 1024 * 1024
        _, all_memory = torch.cuda.memory.mem_get_info()
        total_allowed = 120 * mb
        fraction_allowed = total_allowed / all_memory
        assert int(fraction_allowed * all_memory) == total_allowed
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)

        def alloc(n):
            return torch.ones(n * mb, dtype=torch.int8, device="cuda")

        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,garbage_collection_threshold:0.5"
        )
        a = alloc(40)
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:True,garbage_collection_threshold:0.5"
        )
        b = alloc(40)
        del a, b
        # This allocation causes GC to run. The expandable segment block will be
        # split, so GC would not attempt to free it anyway, but it at least makes
        # sure expandable_segment blocks can be in the free list when GC runs.
        alloc(80)
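        # Informal note (not asserted by this test): garbage_collection_threshold:0.5
        # asks the caching allocator to start reclaiming unused cached blocks once
        # usage crosses roughly half of the allowed memory fraction, rather than
        # waiting for an OOM-driven release of all cached blocks.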

    def test_allocator_settings(self):
        def power2_div(size, div_factor):
            pow2 = 1
            while pow2 < size:
                pow2 = pow2 * 2
            if pow2 == size:
                return pow2
            step = pow2 / 2 / div_factor
            ret = pow2 / 2
            while ret < size:
                ret = ret + step
            return ret
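        # Hand-worked example of the helper above (illustrative only, not an
        # assertion from the original test): rounding 84 MiB (21 * 1024 * 1024
        # floats) with div_factor=4 takes the next power of two, 128 MiB, and
        # rounds up from 64 MiB in 16 MiB steps, i.e.
        # power2_div(21 * 1024 * 1024 * 4, 4) == 96 * 1024 * 1024.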

        torch.cuda.memory.empty_cache()
        key_allocated = (
            "active_bytes.all.allocated"
            if not TEST_CUDAMALLOCASYNC
            else "allocated_bytes.all.current"
        )
        key_requested = "requested_bytes.all.allocated"

        nelems = 21 * 1024 * 1024
        nbytes = 4 * nelems  # floats are 4 bytes

        nelems_big = 100 * 1024 * 1024
        nbytes_big = 4 * nelems_big  # floats are 4 bytes

        start_mem = torch.cuda.memory_stats()[key_allocated]
        torch.cuda.memory._set_allocator_settings("")
        x = torch.rand(nelems, device="cuda")

        # test roundup_power2_divisions single value syntax
        reg_mem = torch.cuda.memory_stats()[key_allocated]
        start_requested = torch.cuda.memory_stats()[key_requested]
        torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4")
        y = torch.rand(nelems, device="cuda")

        pow2_div4_mem = torch.cuda.memory_stats()[key_allocated]
        current_requested = torch.cuda.memory_stats()[key_requested]

        self.assertTrue(reg_mem - start_mem == nbytes)
        if not TEST_CUDAMALLOCASYNC:
            # not supported with the cudaMallocAsync backend
            self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))
            self.assertTrue(current_requested - start_requested == nbytes)

        torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5")
        torch.cuda.memory._set_allocator_settings(
            "garbage_collection_threshold:0.5,max_split_size_mb:40"
        )

        # should have reset the power2 divisions now
        torch.cuda.memory.empty_cache()
        start_mem = torch.cuda.memory_stats()[key_allocated]
        z = torch.rand(nelems, device="cuda")
        reg_mem = torch.cuda.memory_stats()[key_allocated]
        self.assertTrue(reg_mem - start_mem == nbytes)

        # roundup_power2_divisions knob array syntax
        torch.cuda.memory.empty_cache()
        torch.cuda.memory._set_allocator_settings(
            "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]"
        )
        start_mem = torch.cuda.memory_stats()[key_allocated]
        w = torch.rand(nelems, device="cuda")

        pow2_div8_mem = torch.cuda.memory_stats()[key_allocated]
        if not TEST_CUDAMALLOCASYNC:
            # not supported with the cudaMallocAsync backend
            self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8))

        torch.cuda.memory.empty_cache()
        start_mem = torch.cuda.memory_stats()[key_allocated]
        v = torch.rand(nelems_big, device="cuda")

        pow2_div2_mem = torch.cuda.memory_stats()[key_allocated]
        if not TEST_CUDAMALLOCASYNC:
            # not supported with the cudaMallocAsync backend
            self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2))

        torch.cuda.memory.empty_cache()
        torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True")
        start_mem = torch.cuda.memory_stats()[key_allocated]
        w = torch.rand(nelems, device="cuda")
        reg_mem = torch.cuda.memory_stats()[key_allocated]
        self.assertTrue(reg_mem - start_mem == nbytes)

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("foo:1,bar:2")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "garbage_collection_threshold:1.2"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("max_split_size_mb:2")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_use_cuda_host_register:none"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_num_register_threads:none"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_num_register_threads:1024"
            )

    @parametrize("max_split_size_mb_setting", [False, True])
    def test_raises_oom(self, max_split_size_mb_setting):
        if max_split_size_mb_setting:
            # The CUDACachingAllocator returns early when searching for available
            # blocks if max_split_size_mb is not set.
            # Setting it here exercises more of that code path.
            torch.cuda.memory._set_allocator_settings("max_split_size_mb:1024")
            torch.cuda.memory.empty_cache()
        with self.assertRaises(torch.cuda.OutOfMemoryError):
            torch.empty(1024 * 1024 * 1024 * 1024, device="cuda")

    @unittest.skipIf(
        not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux"
    )
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_cpp_memory_snapshot_pickle(self):
        from torch.utils.cpp_extension import load_inline

        source = """
        #include <torch/csrc/cuda/memory_snapshot.h>
        py::object do_snapshot() {
            std::string data = torch::cuda::_memory_snapshot_pickled();
            return py::bytes(data);
        }
        void record(bool e, bool ctx) {
            torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx);
        }
        """
        m = load_inline(
            name="snapshot", cpp_sources=[source], functions=["do_snapshot", "record"]
        )
        for ctx in (False, True):
            try:
                m.record(True, ctx)

                @torch.jit.script
                def the_script_fn():
                    return torch.rand(311, 411, device="cuda")

                def run():
                    t = the_script_fn()
                    return pickle.loads(m.do_snapshot())

                mem = run()
                found = False
                for s in mem["segments"]:
                    for b in s["blocks"]:
                        if b["state"] == "active_allocated":
                            if b["requested_size"] == 311 * 411 * 4:
                                if ctx:
                                    frame_text = str(b["frames"])
                                    # C++ frame
                                    self.assertTrue("::rand" in frame_text)
                                    # script frame
                                    self.assertTrue("the_script_fn" in frame_text)
                                    # python frame
                                    self.assertTrue("case.py" in frame_text)
                                found = True
                last_action = mem["device_traces"][0][-1]
                self.assertTrue(last_action["action"] == "alloc")
                self.assertTrue(last_action["size"] == 311 * 411 * 4)
                self.assertTrue(found)
            finally:
                m.record(False, False)

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    def test_notifies_oom(self):
        x = False

        def cb(device, alloc, device_alloc, device_free):
            nonlocal x
            x = True

        torch._C._cuda_attach_out_of_memory_observer(cb)
        with self.assertRaises(torch.cuda.OutOfMemoryError):
            torch.empty(1024 * 1024 * 1024 * 1024, device="cuda")
        self.assertTrue(x)

    def test_allocator_fuzz(self):
        # fuzz
        state = random.getstate()
        random.seed(123)
        N = 10000
        try:
            mem = []
            total = 0
            c = 0

            def alloc():
                nonlocal total, c
                b = random.randrange(2 * 1024 * 1024 // 4, 20 * 1024 * 1024 // 4)
                mem.append((c, torch.full((b,), c, dtype=torch.int32, device="cuda")))
                c += 1
                total += b

            def free():
                nonlocal total
                idx = random.randrange(0, len(mem))
                v, x = mem.pop(idx)
                assert torch.all(v == x)
                total -= x.numel()

            choices = [alloc, free, torch.cuda.memory.empty_cache]
            for i in range(N):
                while total >= 1024 * 1024 * 1024 / (4 * 10):
                    free()
                (action,) = random.choices(choices, weights=[1, 1 if mem else 0, 0.1])
                action()
        finally:
            random.setstate(state)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_nvml_get_handler(self):
        if not torch.version.hip:
            self.assertTrue(torch.cuda._get_pynvml_handler() is not None)
        else:
            self.assertTrue(torch.cuda._get_amdsmi_handler() is not None)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_temperature(self):
        self.assertTrue(0 <= torch.cuda.temperature() <= 150)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_power_draw(self):
        self.assertTrue(torch.cuda.power_draw() >= 0)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_clock_speed(self):
        self.assertTrue(torch.cuda.clock_rate() >= 0)


MIN_BLOCK_SIZE = 512
SMALL_SIZE = 1048576
SMALL_BUFFER = 2097152
LARGE_BUFFER = 20971520
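# For reference (informal note added for readability): SMALL_SIZE is 1 MiB,
# SMALL_BUFFER is 2 MiB and LARGE_BUFFER is 20 MiB; MIN_BLOCK_SIZE matches the
# caching allocator's 512-byte minimum block granularity.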


def get_cudagraph_segments(pool_id):
    segments = torch.cuda.memory_snapshot()
    return [segment for segment in segments if segment["segment_pool_id"] == pool_id]


def get_all_cudagraph_segments():
    segments = torch.cuda.memory_snapshot()
    return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]
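# Note: the default allocator pool reports segment_pool_id == (0, 0), so the
# helper above treats any non-(0, 0) pool id as belonging to a CUDA graph
# private pool.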


def cudagraphify(fn, inputs, pool=None):
    if not TEST_CUDA_GRAPH:
        raise unittest.SkipTest("cuda graph test is skipped")

    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        fn(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, pool=pool):
        static_outputs = fn(*inputs)

    return graph, static_outputs
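# A minimal usage sketch for cudagraphify (illustrative only; the tests below
# supply their own functions and pools):
#
#   graph, (out,) = cudagraphify(lambda t: (t * 2,), [torch.ones(4, device="cuda")])
#   graph.replay()  # re-runs the captured kernels into the same static buffers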


def int8_cuda(size):
    return torch.ones([size], device="cuda", dtype=torch.uint8)


def live_blocks(pool_id):
    blocks = 0
    for segment in get_cudagraph_segments(pool_id):
        for block in segment["blocks"]:
            blocks += block["state"] == "active_allocated"
    return blocks


def tensor_metadata(x):
    return {
        "nbytes": x.untyped_storage().nbytes(),
        "data_ptr": x.untyped_storage().data_ptr(),
        "size": x.shape,
        "stride": x.stride(),
        "dtype": x.dtype,
        "device": x.device,
        "storage_offset": x.storage_offset(),
    }


def reconstruct_from_tensor_metadata(metadata):
    s = torch._C._construct_storage_from_data_pointer(
        metadata["data_ptr"], metadata["device"], metadata["nbytes"]
    )
    t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"])
    t.set_(
        source=s,
        storage_offset=metadata["storage_offset"],
        size=metadata["size"],
        stride=metadata["stride"],
    )
    return t
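# Illustrative round trip (kept as a comment rather than an executed check):
# reconstruction is expected to alias the original storage rather than copy it,
# e.g.
#
#   x = torch.arange(4.0, device="cuda")
#   y = reconstruct_from_tensor_metadata(tensor_metadata(x))
#   assert y.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()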


@unittest.skipIf(not TEST_CUDA or TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestBlockStateAbsorption(TestCase):
    @property
    def expandable_segments(self):
        return EXPANDABLE_SEGMENTS

    def checkCheckpointedBlock(self, before_block, after_block):
        for field in ("size", "state"):
            self.assertEqual(before_block[field], after_block[field])

    def checkCheckpointedState(self, before_segments, after_segments):
        # after_segments may contain additional segments, but every segment
        # present in before_segments should appear there unchanged
        after_ptr_to_segment = {
            segment["address"]: segment for segment in after_segments
        }

        for before_segment in before_segments:
            self.assertTrue(before_segment["address"] in after_ptr_to_segment)
            after_segment = after_ptr_to_segment[before_segment["address"]]

            for field in (
                "device",
                "total_size",
                "allocated_size",
                "active_size",
                "segment_type",
                "segment_pool_id",
            ):
                self.assertEqual(before_segment[field], after_segment[field])

            self.assertEqual(
                len(before_segment["blocks"]), len(after_segment["blocks"])
            )
            for before_block, after_block in zip(
                before_segment["blocks"], after_segment["blocks"]
            ):
                self.checkCheckpointedBlock(before_block, after_block)

    @staticmethod
    def setCheckpointPoolState(
        device, state, stale_storages_ptr, storages_deleters=None
    ):
        stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr]
        storages_deleters = (
            []
            if not storages_deleters
            else [t.untyped_storage()._cdata for t in storages_deleters]
        )
        torch._C._cuda_setCheckpointPoolState(
            device, state, stale_storages_ptr, storages_deleters
        )

    def checkFunction(self, fn, inputs, pool=None):
        graph, outputs = cudagraphify(fn, inputs, pool=pool)

        pool_id = graph.pool()
        device = outputs[0].device.index

        segments_before_checkpoint = get_cudagraph_segments(pool_id)

        state = torch._C._cuda_getCheckpointState(device, pool_id)
        self.setCheckpointPoolState(device, state, [], [])

        self.checkCheckpointedState(
            segments_before_checkpoint, get_cudagraph_segments(pool_id)
        )

    def setUp(self):
        super().setUp()
        self.segment_length = len(get_all_cudagraph_segments())

    def tearDown(self):
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        self.assertEqual(len(get_all_cudagraph_segments()), self.segment_length)

        super().tearDown()

    def test_simple(self):
        def foo():
            x = torch.zeros([SMALL_SIZE * 8], device="cuda", dtype=torch.uint8)
            x = x + x
            x1 = int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE)
            y = int8_cuda(SMALL_SIZE) + x1
            z = int8_cuda(SMALL_SIZE)
            return x, y, z

        self.checkFunction(foo, [])

    def test_allocated_in_middle_of_segment(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5].add_(2)

        self.checkFunction(foo, [])

    def test_multiple_middle_allocations(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5], small_buffers[8]

        self.checkFunction(foo, [])

    def test_middle_allocations_contiguous(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5], small_buffers[6]

        self.checkFunction(foo, [])

    def test_additional_free_following_checkpoint(self):
        def foo():
            return (int8_cuda(MIN_BLOCK_SIZE),)

        def foo2():
            return (int8_cuda(MIN_BLOCK_SIZE),)

        graph, outputs = cudagraphify(foo, [])
        pool_id = graph.pool()

        segments_before_checkpoint = get_cudagraph_segments(pool_id)

        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

        graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())

        self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])

        del outputs2

        self.checkCheckpointedState(
            segments_before_checkpoint, get_cudagraph_segments(pool_id)
        )

4096*da0073e9SAndroid Build Coastguard Worker    # TODO: re-enable
4097*da0073e9SAndroid Build Coastguard Worker    # def test_additional_free_error(self):
4098*da0073e9SAndroid Build Coastguard Worker    #     def foo():
4099*da0073e9SAndroid Build Coastguard Worker    #         return int8_cuda(MIN_BLOCK_SIZE),
4100*da0073e9SAndroid Build Coastguard Worker
4101*da0073e9SAndroid Build Coastguard Worker    #     def foo2():
4102*da0073e9SAndroid Build Coastguard Worker    #         return int8_cuda(MIN_BLOCK_SIZE),
4103*da0073e9SAndroid Build Coastguard Worker
4104*da0073e9SAndroid Build Coastguard Worker    #     graph, outputs = cudagraphify(foo, [])
4105*da0073e9SAndroid Build Coastguard Worker    #     pool_id = graph.pool()
4106*da0073e9SAndroid Build Coastguard Worker
4107*da0073e9SAndroid Build Coastguard Worker    #     segments_before_checkpoint = get_cudagraph_segments(pool_id)
4108*da0073e9SAndroid Build Coastguard Worker
4109*da0073e9SAndroid Build Coastguard Worker    #     state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)
4110*da0073e9SAndroid Build Coastguard Worker
4111*da0073e9SAndroid Build Coastguard Worker    #     graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
4112*da0073e9SAndroid Build Coastguard Worker    #     with self.assertRaisesRegex(Exception, "being manually freed must be passed"):
4113*da0073e9SAndroid Build Coastguard Worker    #         self.setCheckpointPoolState(outputs[0].device.index, state, [], [])
4114*da0073e9SAndroid Build Coastguard Worker
4115*da0073e9SAndroid Build Coastguard Worker    def test_tensor_dies_after_checkpoint(self):
4116*da0073e9SAndroid Build Coastguard Worker        def foo():
4117*da0073e9SAndroid Build Coastguard Worker            return int8_cuda(MIN_BLOCK_SIZE), int8_cuda(MIN_BLOCK_SIZE)
4118*da0073e9SAndroid Build Coastguard Worker
4119*da0073e9SAndroid Build Coastguard Worker        graph, outputs = cudagraphify(foo, [])
4120*da0073e9SAndroid Build Coastguard Worker        pool_id = graph.pool()
4121*da0073e9SAndroid Build Coastguard Worker        device = outputs[0].device.index
4122*da0073e9SAndroid Build Coastguard Worker
4123*da0073e9SAndroid Build Coastguard Worker        segments_before_checkpoint = get_cudagraph_segments(pool_id)
4124*da0073e9SAndroid Build Coastguard Worker        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)
4125*da0073e9SAndroid Build Coastguard Worker
4126*da0073e9SAndroid Build Coastguard Worker        output_data_ptrs = [output.data_ptr() for output in outputs]
4127*da0073e9SAndroid Build Coastguard Worker
4128*da0073e9SAndroid Build Coastguard Worker        del outputs
4129*da0073e9SAndroid Build Coastguard Worker
4130*da0073e9SAndroid Build Coastguard Worker        self.setCheckpointPoolState(device, state, [], [])
4131*da0073e9SAndroid Build Coastguard Worker
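        # Restoring the checkpoint marks both output blocks as live again even though
        # the Python tensors were deleted above, so the blocks have to be released
        # manually through the caching allocator, as the raw_delete calls below do.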
4132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 2)
4133*da0073e9SAndroid Build Coastguard Worker        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0])
4134*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 1)
4135*da0073e9SAndroid Build Coastguard Worker        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[1])
4136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 0)
4137*da0073e9SAndroid Build Coastguard Worker
4138*da0073e9SAndroid Build Coastguard Worker    def test_assigning_back_deleter_fns_to_tensor(self):
4139*da0073e9SAndroid Build Coastguard Worker        def foo(x):
4140*da0073e9SAndroid Build Coastguard Worker            return (
4141*da0073e9SAndroid Build Coastguard Worker                int8_cuda(SMALL_BUFFER) + x,
4142*da0073e9SAndroid Build Coastguard Worker                int8_cuda(SMALL_BUFFER) + x,
4143*da0073e9SAndroid Build Coastguard Worker                int8_cuda(LARGE_BUFFER) + x,
4144*da0073e9SAndroid Build Coastguard Worker            )
4145*da0073e9SAndroid Build Coastguard Worker
4146*da0073e9SAndroid Build Coastguard Worker        inp = torch.tensor([1], device="cuda")
4147*da0073e9SAndroid Build Coastguard Worker        graph, outputs = cudagraphify(foo, [inp])
4148*da0073e9SAndroid Build Coastguard Worker        pool_id = graph.pool()
4149*da0073e9SAndroid Build Coastguard Worker        graph.replay()
4150*da0073e9SAndroid Build Coastguard Worker
4151*da0073e9SAndroid Build Coastguard Worker        device = outputs[0].device.index
4152*da0073e9SAndroid Build Coastguard Worker
4153*da0073e9SAndroid Build Coastguard Worker        for i in range(len(outputs)):
4154*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(outputs[i].mean(dtype=torch.float) == 2)
4155*da0073e9SAndroid Build Coastguard Worker
4156*da0073e9SAndroid Build Coastguard Worker        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)
4157*da0073e9SAndroid Build Coastguard Worker
4158*da0073e9SAndroid Build Coastguard Worker        output_ptrs = [output.untyped_storage().data_ptr() for output in outputs]
4159*da0073e9SAndroid Build Coastguard Worker        ten_metadata = [tensor_metadata(t) for t in outputs]
4160*da0073e9SAndroid Build Coastguard Worker
4161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 3)
4162*da0073e9SAndroid Build Coastguard Worker
4163*da0073e9SAndroid Build Coastguard Worker        del outputs
4164*da0073e9SAndroid Build Coastguard Worker
4165*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 0)
4166*da0073e9SAndroid Build Coastguard Worker
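        # Rebuild tensors over the same graph-pool memory from the saved metadata.
        # At this point they are plain views without deleters, so they do not keep
        # the underlying blocks alive on their own.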
4167*da0073e9SAndroid Build Coastguard Worker        reconstructed_tensors = [
4168*da0073e9SAndroid Build Coastguard Worker            reconstruct_from_tensor_metadata(metadata) for metadata in ten_metadata
4169*da0073e9SAndroid Build Coastguard Worker        ]
4170*da0073e9SAndroid Build Coastguard Worker
4171*da0073e9SAndroid Build Coastguard Worker        for i in range(len(reconstructed_tensors)):
4172*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2)
4173*da0073e9SAndroid Build Coastguard Worker
4174*da0073e9SAndroid Build Coastguard Worker        inp.add_(1)
4175*da0073e9SAndroid Build Coastguard Worker        graph.replay()
4176*da0073e9SAndroid Build Coastguard Worker
4177*da0073e9SAndroid Build Coastguard Worker        for i in range(len(reconstructed_tensors)):
4178*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3)
4179*da0073e9SAndroid Build Coastguard Worker
4180*da0073e9SAndroid Build Coastguard Worker        self.setCheckpointPoolState(
4181*da0073e9SAndroid Build Coastguard Worker            device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]]
4182*da0073e9SAndroid Build Coastguard Worker        )
4183*da0073e9SAndroid Build Coastguard Worker
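        # The checkpoint restore attaches deleter fns to the first two reconstructed
        # tensors, so dropping them below frees their blocks; the third tensor was not
        # passed in, so its block must be freed explicitly via raw_delete.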
4184*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 3)
4185*da0073e9SAndroid Build Coastguard Worker
4186*da0073e9SAndroid Build Coastguard Worker        reconstructed_tensors[0] = None
4187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 2)
4188*da0073e9SAndroid Build Coastguard Worker
4189*da0073e9SAndroid Build Coastguard Worker        reconstructed_tensors[1] = None
4190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 1)
4191*da0073e9SAndroid Build Coastguard Worker
4192*da0073e9SAndroid Build Coastguard Worker        # should not change; this tensor was not passed in to swap data ptrs
4193*da0073e9SAndroid Build Coastguard Worker        reconstructed_tensors[2] = None
4194*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 1)
4195*da0073e9SAndroid Build Coastguard Worker
4196*da0073e9SAndroid Build Coastguard Worker        torch._C._cuda_cudaCachingAllocator_raw_delete(output_ptrs[2])
4197*da0073e9SAndroid Build Coastguard Worker
4198*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(live_blocks(pool_id), 0)
4199*da0073e9SAndroid Build Coastguard Worker
4200*da0073e9SAndroid Build Coastguard Worker    @skipIfNoTorchVision
4201*da0073e9SAndroid Build Coastguard Worker    def test_resnet(self):
4202*da0073e9SAndroid Build Coastguard Worker        import torchvision
4203*da0073e9SAndroid Build Coastguard Worker
4204*da0073e9SAndroid Build Coastguard Worker        m = torchvision.models.resnet50()
4205*da0073e9SAndroid Build Coastguard Worker        m.eval()
4206*da0073e9SAndroid Build Coastguard Worker        m = m.cuda()
4207*da0073e9SAndroid Build Coastguard Worker
4208*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand([1, 3, 255, 255], device="cuda")
4209*da0073e9SAndroid Build Coastguard Worker        self.checkFunction(m, [inp])
4210*da0073e9SAndroid Build Coastguard Worker
4211*da0073e9SAndroid Build Coastguard Worker    def test_check_pool_live_allocations(self):
4212*da0073e9SAndroid Build Coastguard Worker        def foo():
4213*da0073e9SAndroid Build Coastguard Worker            return torch.ones([4], device="cuda")
4214*da0073e9SAndroid Build Coastguard Worker
4215*da0073e9SAndroid Build Coastguard Worker        pool = torch.cuda.graph_pool_handle()
4216*da0073e9SAndroid Build Coastguard Worker        graph, outputs = cudagraphify(foo, [], pool=pool)
4217*da0073e9SAndroid Build Coastguard Worker
4218*da0073e9SAndroid Build Coastguard Worker        index = outputs[0].device.index
4219*da0073e9SAndroid Build Coastguard Worker
4220*da0073e9SAndroid Build Coastguard Worker        def check(live_dps):
4221*da0073e9SAndroid Build Coastguard Worker            return torch._C._cuda_checkPoolLiveAllocations(index, pool, live_dps)
4222*da0073e9SAndroid Build Coastguard Worker
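        # The check passes only when live_dps matches the pool's live allocations
        # exactly; the assertions below probe the matching, superset, and empty cases.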
4223*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check({outputs[0].data_ptr()}))
4224*da0073e9SAndroid Build Coastguard Worker
4225*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(check({outputs[0].data_ptr(), 0}))
4226*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(check(set()))
4227*da0073e9SAndroid Build Coastguard Worker
4228*da0073e9SAndroid Build Coastguard Worker        del outputs
4229*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check(set()))
4230*da0073e9SAndroid Build Coastguard Worker
4231*da0073e9SAndroid Build Coastguard Worker    def test_allocate_in_thread_to_pool(self):
4232*da0073e9SAndroid Build Coastguard Worker        def foo():
4233*da0073e9SAndroid Build Coastguard Worker            return torch.rand([4], device="cuda")
4234*da0073e9SAndroid Build Coastguard Worker
4235*da0073e9SAndroid Build Coastguard Worker        pool = torch.cuda.graph_pool_handle()
4236*da0073e9SAndroid Build Coastguard Worker        graph, outputs = cudagraphify(foo, [], pool=pool)
4237*da0073e9SAndroid Build Coastguard Worker        device = outputs[0].device.index
4238*da0073e9SAndroid Build Coastguard Worker        del outputs
4239*da0073e9SAndroid Build Coastguard Worker
4240*da0073e9SAndroid Build Coastguard Worker        @contextlib.contextmanager
4241*da0073e9SAndroid Build Coastguard Worker        def _use_cuda_memory_pool_manager(device, mem_pool):
4242*da0073e9SAndroid Build Coastguard Worker            """
4243*da0073e9SAndroid Build Coastguard Worker            Context manager that routes new allocations on the current stream to the given CUDA graph pool.
4244*da0073e9SAndroid Build Coastguard Worker            While it is active, all cudagraph tensors still in use must be reflected in the allocator or they may be overwritten.
4245*da0073e9SAndroid Build Coastguard Worker            The graph that owns mem_pool should already have been used in a capture, and mem_pool must already exist.
4246*da0073e9SAndroid Build Coastguard Worker            """
4247*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
4248*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.Stream()
4249*da0073e9SAndroid Build Coastguard Worker            stream.wait_stream(torch.cuda.current_stream())
4250*da0073e9SAndroid Build Coastguard Worker            stream_context = torch.cuda.stream(stream)
4251*da0073e9SAndroid Build Coastguard Worker            stream_context.__enter__()
4252*da0073e9SAndroid Build Coastguard Worker            torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
4253*da0073e9SAndroid Build Coastguard Worker            try:
4254*da0073e9SAndroid Build Coastguard Worker                yield
4255*da0073e9SAndroid Build Coastguard Worker            finally:
4256*da0073e9SAndroid Build Coastguard Worker                torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
4257*da0073e9SAndroid Build Coastguard Worker                torch._C._cuda_releasePool(device, mem_pool)
4258*da0073e9SAndroid Build Coastguard Worker                stream_context.__exit__(None, None, None)
4259*da0073e9SAndroid Build Coastguard Worker
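        # Rough usage sketch for the helper above (names as defined in this test):
        #
        #     with _use_cuda_memory_pool_manager(device, pool):
        #         tmp = int8_cuda(LARGE_BUFFER)  # served from the cudagraph pool
        #
        # The threads below compare allocations routed to the pool this way against
        # ordinary caching-allocator allocations made concurrently.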
4260*da0073e9SAndroid Build Coastguard Worker        segments = get_cudagraph_segments(pool)
4261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(get_cudagraph_segments(pool)), 1)
4262*da0073e9SAndroid Build Coastguard Worker
4263*da0073e9SAndroid Build Coastguard Worker        def use_pool():
4264*da0073e9SAndroid Build Coastguard Worker            def alloc_three():
4265*da0073e9SAndroid Build Coastguard Worker                a = int8_cuda(LARGE_BUFFER)
4266*da0073e9SAndroid Build Coastguard Worker                b = int8_cuda(LARGE_BUFFER)
4267*da0073e9SAndroid Build Coastguard Worker                c = a + b
4268*da0073e9SAndroid Build Coastguard Worker
4269*da0073e9SAndroid Build Coastguard Worker            with _use_cuda_memory_pool_manager(device, pool):
4270*da0073e9SAndroid Build Coastguard Worker                # three allocations per iteration, all routed to the pool
4271*da0073e9SAndroid Build Coastguard Worker                for _ in range(10):
4272*da0073e9SAndroid Build Coastguard Worker                    alloc_three()
4273*da0073e9SAndroid Build Coastguard Worker
4274*da0073e9SAndroid Build Coastguard Worker            # three more allocations not in pool
4275*da0073e9SAndroid Build Coastguard Worker            alloc_three()
4276*da0073e9SAndroid Build Coastguard Worker
4277*da0073e9SAndroid Build Coastguard Worker        def no_pool():
4278*da0073e9SAndroid Build Coastguard Worker            # two allocations
4279*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
4280*da0073e9SAndroid Build Coastguard Worker                a = int8_cuda(LARGE_BUFFER)
4281*da0073e9SAndroid Build Coastguard Worker                b = int8_cuda(LARGE_BUFFER)
4282*da0073e9SAndroid Build Coastguard Worker                del a, b
4283*da0073e9SAndroid Build Coastguard Worker
4284*da0073e9SAndroid Build Coastguard Worker        graph_thread = threading.Thread(target=use_pool)
4285*da0073e9SAndroid Build Coastguard Worker        no_graph_thread = threading.Thread(target=no_pool)
4286*da0073e9SAndroid Build Coastguard Worker        graph_thread.start()
4287*da0073e9SAndroid Build Coastguard Worker        no_graph_thread.start()
4288*da0073e9SAndroid Build Coastguard Worker
4289*da0073e9SAndroid Build Coastguard Worker        graph_thread.join()
4290*da0073e9SAndroid Build Coastguard Worker        no_graph_thread.join()
4291*da0073e9SAndroid Build Coastguard Worker
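        # Pooled allocations and the other thread's ordinary allocations land in
        # separate segments; with expandable segments enabled they are coalesced,
        # hence the two expected counts below.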
4292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4293*da0073e9SAndroid Build Coastguard Worker            len(get_cudagraph_segments(pool)), 2 if self.expandable_segments else 4
4294*da0073e9SAndroid Build Coastguard Worker        )
4295*da0073e9SAndroid Build Coastguard Worker
4296*da0073e9SAndroid Build Coastguard Worker        del graph
4297*da0073e9SAndroid Build Coastguard Worker
4298*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
4299*da0073e9SAndroid Build Coastguard Worker        gc.collect()
4300*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
4301*da0073e9SAndroid Build Coastguard Worker
4302*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(get_cudagraph_segments(pool)), 0)
4303*da0073e9SAndroid Build Coastguard Worker
4304*da0073e9SAndroid Build Coastguard Worker    def test_no_triton_on_import(self):
4305*da0073e9SAndroid Build Coastguard Worker        """Test that Triton is not imported on first GPU use"""
4306*da0073e9SAndroid Build Coastguard Worker        script = "import sys; import torch; torch.rand(2, device='cuda'); print('triton' in sys.modules)"
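        # Run the probe in a fresh interpreter so sys.modules only reflects what this
        # import path actually pulls in.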
4307*da0073e9SAndroid Build Coastguard Worker
4308*da0073e9SAndroid Build Coastguard Worker        rc = (
4309*da0073e9SAndroid Build Coastguard Worker            subprocess.check_output(
4310*da0073e9SAndroid Build Coastguard Worker                [sys.executable, "-c", script],
4311*da0073e9SAndroid Build Coastguard Worker                # On Windows, opening the subprocess with the default CWD makes `import torch`
4312*da0073e9SAndroid Build Coastguard Worker                # fail, so just set CWD to this script's directory
4313*da0073e9SAndroid Build Coastguard Worker                cwd=os.path.dirname(os.path.realpath(__file__)),
4314*da0073e9SAndroid Build Coastguard Worker            )
4315*da0073e9SAndroid Build Coastguard Worker            .strip()
4316*da0073e9SAndroid Build Coastguard Worker            .decode("ascii")
4317*da0073e9SAndroid Build Coastguard Worker        )
4318*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rc, "False", "Triton was imported when importing torch!")
4319*da0073e9SAndroid Build Coastguard Worker
4320*da0073e9SAndroid Build Coastguard Worker
4321*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
4322*da0073e9SAndroid Build Coastguard Workerclass TestMemPool(TestCase):
4323*da0073e9SAndroid Build Coastguard Worker    def test_mempool_id(self):
4324*da0073e9SAndroid Build Coastguard Worker        pool1 = torch.cuda.graph_pool_handle()
4325*da0073e9SAndroid Build Coastguard Worker        pool2 = torch.cuda.MemPool().id
4326*da0073e9SAndroid Build Coastguard Worker
4327*da0073e9SAndroid Build Coastguard Worker        # first value of id in a user created pool is always zero
4328*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pool1[0] == 0, pool2[0] == 0)
4329*da0073e9SAndroid Build Coastguard Worker
4330*da0073e9SAndroid Build Coastguard Worker        # each call to torch.cuda.graph_pool_handle() or torch.cuda.MemPool()
4331*da0073e9SAndroid Build Coastguard Worker        # increments the id
4332*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(abs(pool2[1] - pool1[1]) > 0)
4333*da0073e9SAndroid Build Coastguard Worker
4334*da0073e9SAndroid Build Coastguard Worker    def test_mempool_with_allocator(self):
4335*da0073e9SAndroid Build Coastguard Worker        pool = torch.cuda.MemPool()
4336*da0073e9SAndroid Build Coastguard Worker
4337*da0073e9SAndroid Build Coastguard Worker        # MemPool doesn't have an allocator by default
4338*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pool.allocator, None)
4339*da0073e9SAndroid Build Coastguard Worker
4340*da0073e9SAndroid Build Coastguard Worker        from torch.utils.cpp_extension import load_inline
4341*da0073e9SAndroid Build Coastguard Worker
4342*da0073e9SAndroid Build Coastguard Worker        dummy_allocator_source = """
4343*da0073e9SAndroid Build Coastguard Worker        #include <torch/extension.h>
4344*da0073e9SAndroid Build Coastguard Worker        #include <ATen/cuda/Exceptions.h>
4345*da0073e9SAndroid Build Coastguard Worker        #include <cuda_runtime_api.h>
4346*da0073e9SAndroid Build Coastguard Worker
4347*da0073e9SAndroid Build Coastguard Worker        extern "C" {
4348*da0073e9SAndroid Build Coastguard Worker          C10_EXPORT int called_dummy_alloc = 0;
4349*da0073e9SAndroid Build Coastguard Worker          C10_EXPORT int called_dummy_free = 0;
4350*da0073e9SAndroid Build Coastguard Worker
4351*da0073e9SAndroid Build Coastguard Worker          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
4352*da0073e9SAndroid Build Coastguard Worker          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
4353*da0073e9SAndroid Build Coastguard Worker            called_dummy_alloc = 123;
4354*da0073e9SAndroid Build Coastguard Worker            void* ptr;
4355*da0073e9SAndroid Build Coastguard Worker            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
4356*da0073e9SAndroid Build Coastguard Worker            return ptr;
4357*da0073e9SAndroid Build Coastguard Worker          }
4358*da0073e9SAndroid Build Coastguard Worker
4359*da0073e9SAndroid Build Coastguard Worker          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
4360*da0073e9SAndroid Build Coastguard Worker            called_dummy_free = 321;
4361*da0073e9SAndroid Build Coastguard Worker            C10_CUDA_CHECK(cudaFree(ptr));
4362*da0073e9SAndroid Build Coastguard Worker          }
4363*da0073e9SAndroid Build Coastguard Worker        }
4364*da0073e9SAndroid Build Coastguard Worker        """
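        # Build the allocator above as a shared library. With is_python_module=False,
        # load_inline returns the path of the compiled extension, which is what
        # CUDAPluggableAllocator and ctypes.CDLL below expect.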
4365*da0073e9SAndroid Build Coastguard Worker        dummy_allocator_libname = "dummy_allocator"
4366*da0073e9SAndroid Build Coastguard Worker        dummy_allocator = load_inline(
4367*da0073e9SAndroid Build Coastguard Worker            name=dummy_allocator_libname,
4368*da0073e9SAndroid Build Coastguard Worker            cpp_sources=dummy_allocator_source,
4369*da0073e9SAndroid Build Coastguard Worker            is_python_module=False,
4370*da0073e9SAndroid Build Coastguard Worker            keep_intermediates=False,
4371*da0073e9SAndroid Build Coastguard Worker            verbose=True,
4372*da0073e9SAndroid Build Coastguard Worker            with_cuda=True,
4373*da0073e9SAndroid Build Coastguard Worker        )
4374*da0073e9SAndroid Build Coastguard Worker        allocator = torch.cuda.memory.CUDAPluggableAllocator(
4375*da0073e9SAndroid Build Coastguard Worker            dummy_allocator,
4376*da0073e9SAndroid Build Coastguard Worker            "dummy_alloc",
4377*da0073e9SAndroid Build Coastguard Worker            "dummy_free",
4378*da0073e9SAndroid Build Coastguard Worker        )
4379*da0073e9SAndroid Build Coastguard Worker        pool = torch.cuda.MemPool(allocator.allocator())
4380*da0073e9SAndroid Build Coastguard Worker
4381*da0073e9SAndroid Build Coastguard Worker        # pool should point to the same allocator as the one passed into it
4382*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(allocator.allocator(), pool.allocator)
4383*da0073e9SAndroid Build Coastguard Worker
4384*da0073e9SAndroid Build Coastguard Worker        # no allocations happened yet, so called_dummy_alloc should be 0
4385*da0073e9SAndroid Build Coastguard Worker        alloc_lib = ctypes.CDLL(dummy_allocator)
4386*da0073e9SAndroid Build Coastguard Worker        called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
4387*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(called_dummy_alloc.value, 0)
4388*da0073e9SAndroid Build Coastguard Worker
4389*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.use_mem_pool(pool):
4390*da0073e9SAndroid Build Coastguard Worker            out = torch.randn(1, device="cuda")
4391*da0073e9SAndroid Build Coastguard Worker
4392*da0073e9SAndroid Build Coastguard Worker        # called_dummy_alloc should be 123 if dummy_alloc was used to allocate
4393*da0073e9SAndroid Build Coastguard Worker        # out tensor
4394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(called_dummy_alloc.value, 123)
4395*da0073e9SAndroid Build Coastguard Worker
4396*da0073e9SAndroid Build Coastguard Worker    def test_mempool_context(self):
4397*da0073e9SAndroid Build Coastguard Worker        active_pool = torch.cuda.MemPoolContext.active_pool()
4398*da0073e9SAndroid Build Coastguard Worker
4399*da0073e9SAndroid Build Coastguard Worker        # there is no active pool if none was made active
4400*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(active_pool, None)
4401*da0073e9SAndroid Build Coastguard Worker
4402*da0073e9SAndroid Build Coastguard Worker        pool = torch.cuda.MemPool()
4403*da0073e9SAndroid Build Coastguard Worker        ctx = torch.cuda.MemPoolContext(pool)
4404*da0073e9SAndroid Build Coastguard Worker        active_pool = torch.cuda.MemPoolContext.active_pool()
4405*da0073e9SAndroid Build Coastguard Worker
4406*da0073e9SAndroid Build Coastguard Worker        # pool was made active
4407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(active_pool, pool)
4408*da0073e9SAndroid Build Coastguard Worker
4409*da0073e9SAndroid Build Coastguard Worker        del ctx
4410*da0073e9SAndroid Build Coastguard Worker        active_pool = torch.cuda.MemPoolContext.active_pool()
4411*da0073e9SAndroid Build Coastguard Worker
4412*da0073e9SAndroid Build Coastguard Worker        # ctx was deleted, so the active pool reverts to the previous one (None here)
4413*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(active_pool, None)
4414*da0073e9SAndroid Build Coastguard Worker
4415*da0073e9SAndroid Build Coastguard Worker    def test_mempool_multithread(self):
4416*da0073e9SAndroid Build Coastguard Worker        pool_ids = []
4417*da0073e9SAndroid Build Coastguard Worker        active_pool_ids = []
4418*da0073e9SAndroid Build Coastguard Worker
4419*da0073e9SAndroid Build Coastguard Worker        def create_mempool_and_make_active():
4420*da0073e9SAndroid Build Coastguard Worker            pool = torch.cuda.MemPool()
4421*da0073e9SAndroid Build Coastguard Worker            pool_ids.extend([pool.id])
4422*da0073e9SAndroid Build Coastguard Worker
4423*da0073e9SAndroid Build Coastguard Worker            ctx = torch.cuda.MemPoolContext(pool)
4424*da0073e9SAndroid Build Coastguard Worker            active_pool = torch.cuda.MemPoolContext.active_pool()
4425*da0073e9SAndroid Build Coastguard Worker            active_pool_ids.extend([active_pool.id])
4426*da0073e9SAndroid Build Coastguard Worker            del ctx
4427*da0073e9SAndroid Build Coastguard Worker
4428*da0073e9SAndroid Build Coastguard Worker        num_threads = 4
4429*da0073e9SAndroid Build Coastguard Worker        threads = [
4430*da0073e9SAndroid Build Coastguard Worker            threading.Thread(target=create_mempool_and_make_active)
4431*da0073e9SAndroid Build Coastguard Worker            for t in range(num_threads)
4432*da0073e9SAndroid Build Coastguard Worker        ]
4433*da0073e9SAndroid Build Coastguard Worker        for thread in threads:
4434*da0073e9SAndroid Build Coastguard Worker            thread.start()
4435*da0073e9SAndroid Build Coastguard Worker        for thread in threads:
4436*da0073e9SAndroid Build Coastguard Worker            thread.join()
4437*da0073e9SAndroid Build Coastguard Worker
4438*da0073e9SAndroid Build Coastguard Worker        # each thread should create a unique mempool, since
4439*da0073e9SAndroid Build Coastguard Worker        # mempool id creation is atomic
4440*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(set(pool_ids)), 4)
4441*da0073e9SAndroid Build Coastguard Worker
4442*da0073e9SAndroid Build Coastguard Worker        # each thread should have a different active mempool, since
4443*da0073e9SAndroid Build Coastguard Worker        # the pointer to the active mempool is thread local
4444*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(set(active_pool_ids)), 4)
4445*da0073e9SAndroid Build Coastguard Worker
4446*da0073e9SAndroid Build Coastguard Worker
4447*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
4448*da0073e9SAndroid Build Coastguard Worker@torch.testing._internal.common_utils.markDynamoStrictTest
4449*da0073e9SAndroid Build Coastguard Workerclass TestCudaOptims(TestCase):
4450*da0073e9SAndroid Build Coastguard Worker    # These tests will be instantiated with instantiate_device_type_tests
4451*da0073e9SAndroid Build Coastguard Worker    # to apply the new OptimizerInfo structure.
4452*da0073e9SAndroid Build Coastguard Worker
4453*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4454*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
4455*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
4456*da0073e9SAndroid Build Coastguard Worker    )
4457*da0073e9SAndroid Build Coastguard Worker    @optims(
4458*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if optim.has_capturable_arg],
4459*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32],
4460*da0073e9SAndroid Build Coastguard Worker    )
4461*da0073e9SAndroid Build Coastguard Worker    def test_graph_optims(self, device, dtype, optim_info):
4462*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
4463*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
4464*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",)
4465*da0073e9SAndroid Build Coastguard Worker        )
4466*da0073e9SAndroid Build Coastguard Worker
4467*da0073e9SAndroid Build Coastguard Worker        steps_warmup = 3
4468*da0073e9SAndroid Build Coastguard Worker        steps_train = 2
4469*da0073e9SAndroid Build Coastguard Worker
4470*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
4471*da0073e9SAndroid Build Coastguard Worker            kwargs = optim_input.kwargs
4472*da0073e9SAndroid Build Coastguard Worker
4473*da0073e9SAndroid Build Coastguard Worker            # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam
4474*da0073e9SAndroid Build Coastguard Worker            # and torch.optim.adamw
4475*da0073e9SAndroid Build Coastguard Worker            kwargs["lr"] = 0.1
4476*da0073e9SAndroid Build Coastguard Worker
4477*da0073e9SAndroid Build Coastguard Worker            for actually_do_graphs in (True, False):
4478*da0073e9SAndroid Build Coastguard Worker                params = [
4479*da0073e9SAndroid Build Coastguard Worker                    torch.randn((i + 5, i + 5), device=device) for i in range(2)
4480*da0073e9SAndroid Build Coastguard Worker                ] + [torch.randn((), device=device)]
4481*da0073e9SAndroid Build Coastguard Worker                params_control = [p.clone().requires_grad_() for p in params]
4482*da0073e9SAndroid Build Coastguard Worker                params_graphed = [p.clone().requires_grad_() for p in params]
4483*da0073e9SAndroid Build Coastguard Worker
4484*da0073e9SAndroid Build Coastguard Worker                grads = [
4485*da0073e9SAndroid Build Coastguard Worker                    [torch.randn_like(p) for p in params]
4486*da0073e9SAndroid Build Coastguard Worker                    for _ in range(steps_warmup + steps_train)
4487*da0073e9SAndroid Build Coastguard Worker                ]
4488*da0073e9SAndroid Build Coastguard Worker
4489*da0073e9SAndroid Build Coastguard Worker                # Control (capturable=False)
4490*da0073e9SAndroid Build Coastguard Worker                kwargs["capturable"] = False
4491*da0073e9SAndroid Build Coastguard Worker
4492*da0073e9SAndroid Build Coastguard Worker                opt = optim_cls(params_control, **kwargs)
4493*da0073e9SAndroid Build Coastguard Worker                for i in range(steps_warmup + steps_train):
4494*da0073e9SAndroid Build Coastguard Worker                    for j, p in enumerate(params_control):
4495*da0073e9SAndroid Build Coastguard Worker                        p.grad = grads[i][j]
4496*da0073e9SAndroid Build Coastguard Worker                    opt.step()
4497*da0073e9SAndroid Build Coastguard Worker
4498*da0073e9SAndroid Build Coastguard Worker                # capturable=True
4499*da0073e9SAndroid Build Coastguard Worker                kwargs["capturable"] = True
4500*da0073e9SAndroid Build Coastguard Worker                opt = optim_cls(params_graphed, **kwargs)
4501*da0073e9SAndroid Build Coastguard Worker
4502*da0073e9SAndroid Build Coastguard Worker                for i in range(steps_warmup):
4503*da0073e9SAndroid Build Coastguard Worker                    for j, p in enumerate(params_graphed):
4504*da0073e9SAndroid Build Coastguard Worker                        p.grad = grads[i][j]
4505*da0073e9SAndroid Build Coastguard Worker                    opt.step()
4506*da0073e9SAndroid Build Coastguard Worker
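                # Capture a single optimizer step. Replaying the graph reuses the same
                # parameter and grad storage, which is why the training loop below
                # copies new grads into the existing .grad tensors instead of
                # reassigning them.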
4507*da0073e9SAndroid Build Coastguard Worker                if actually_do_graphs:
4508*da0073e9SAndroid Build Coastguard Worker                    g = torch.cuda.CUDAGraph()
4509*da0073e9SAndroid Build Coastguard Worker                    with torch.cuda.graph(g):
4510*da0073e9SAndroid Build Coastguard Worker                        opt.step()
4511*da0073e9SAndroid Build Coastguard Worker
4512*da0073e9SAndroid Build Coastguard Worker                for i in range(steps_train):
4513*da0073e9SAndroid Build Coastguard Worker                    if actually_do_graphs:
4514*da0073e9SAndroid Build Coastguard Worker                        for j, p in enumerate(params_graphed):
4515*da0073e9SAndroid Build Coastguard Worker                            p.grad.copy_(grads[i + steps_warmup][j])
4516*da0073e9SAndroid Build Coastguard Worker                        g.replay()
4517*da0073e9SAndroid Build Coastguard Worker                    else:
4518*da0073e9SAndroid Build Coastguard Worker                        # Passing capturable=True to the constructor and running without graphs should still be
4519*da0073e9SAndroid Build Coastguard Worker                        # numerically correct, even if it's not ideal for performance.
4520*da0073e9SAndroid Build Coastguard Worker                        for j, p in enumerate(params_graphed):
4521*da0073e9SAndroid Build Coastguard Worker                            p.grad = grads[i + steps_warmup][j]
4522*da0073e9SAndroid Build Coastguard Worker                        opt.step()
4523*da0073e9SAndroid Build Coastguard Worker
4524*da0073e9SAndroid Build Coastguard Worker                for p_control, p_graphed in zip(params_control, params_graphed):
4525*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(p_control, p_graphed)
4526*da0073e9SAndroid Build Coastguard Worker
4527*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4528*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
4529*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
4530*da0073e9SAndroid Build Coastguard Worker    )
4531*da0073e9SAndroid Build Coastguard Worker    @optims(
4532*da0073e9SAndroid Build Coastguard Worker        [
4533*da0073e9SAndroid Build Coastguard Worker            optim
4534*da0073e9SAndroid Build Coastguard Worker            for optim in optim_db
4535*da0073e9SAndroid Build Coastguard Worker            if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on
4536*da0073e9SAndroid Build Coastguard Worker        ],
4537*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32],
4538*da0073e9SAndroid Build Coastguard Worker    )
4539*da0073e9SAndroid Build Coastguard Worker    def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
4540*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
4541*da0073e9SAndroid Build Coastguard Worker
4542*da0073e9SAndroid Build Coastguard Worker        steps_warmup = 3
4543*da0073e9SAndroid Build Coastguard Worker        steps_train = 2
4544*da0073e9SAndroid Build Coastguard Worker
4545*da0073e9SAndroid Build Coastguard Worker        optim_inputs = optim_info.optim_inputs_func(device=device)
4546*da0073e9SAndroid Build Coastguard Worker
4547*da0073e9SAndroid Build Coastguard Worker        for optim_input in optim_inputs:
4548*da0073e9SAndroid Build Coastguard Worker            kwargs = optim_input.kwargs
4549*da0073e9SAndroid Build Coastguard Worker            kwargs["fused"] = True
4550*da0073e9SAndroid Build Coastguard Worker
4551*da0073e9SAndroid Build Coastguard Worker            for actually_do_graphs in (
4552*da0073e9SAndroid Build Coastguard Worker                (True, False) if optim_info.has_capturable_arg else (True,)
4553*da0073e9SAndroid Build Coastguard Worker            ):
4554*da0073e9SAndroid Build Coastguard Worker                params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
4555*da0073e9SAndroid Build Coastguard Worker                params_control = [p.clone().requires_grad_() for p in params]
4556*da0073e9SAndroid Build Coastguard Worker                params_graphed = [p.clone().requires_grad_() for p in params]
4557*da0073e9SAndroid Build Coastguard Worker
4558*da0073e9SAndroid Build Coastguard Worker                # `GradScaler` updates gradients in place, so it's necessary to duplicate them.
4559*da0073e9SAndroid Build Coastguard Worker                grads = [
4560*da0073e9SAndroid Build Coastguard Worker                    [torch.randn_like(p) for p in params]
4561*da0073e9SAndroid Build Coastguard Worker                    for _ in range(steps_warmup + steps_train)
4562*da0073e9SAndroid Build Coastguard Worker                ]
4563*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
4564*da0073e9SAndroid Build Coastguard Worker                    grads_control = [[g.clone() for g in gs] for gs in grads]
4565*da0073e9SAndroid Build Coastguard Worker                    grads_graphed = [[g.clone() for g in gs] for gs in grads]
4566*da0073e9SAndroid Build Coastguard Worker
4567*da0073e9SAndroid Build Coastguard Worker                # Gradient Scaler
4568*da0073e9SAndroid Build Coastguard Worker                scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
4569*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
4570*da0073e9SAndroid Build Coastguard Worker                    scaler_for_control._lazy_init_scale_growth_tracker(device)
4571*da0073e9SAndroid Build Coastguard Worker
4572*da0073e9SAndroid Build Coastguard Worker                scaler_for_graphed = torch.cuda.amp.GradScaler()
4573*da0073e9SAndroid Build Coastguard Worker                scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
4574*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
4575*da0073e9SAndroid Build Coastguard Worker                    scaler_for_graphed._lazy_init_scale_growth_tracker(device)
4576*da0073e9SAndroid Build Coastguard Worker
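                # The scale and growth-tracker tensors are initialized eagerly above so
                # they already exist on the device before graph capture; GradScaler
                # would otherwise create them lazily on the first scale()/step() call.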
4577*da0073e9SAndroid Build Coastguard Worker                # Control (capturable=False)
4578*da0073e9SAndroid Build Coastguard Worker                if optim_info.has_capturable_arg:
4579*da0073e9SAndroid Build Coastguard Worker                    kwargs["capturable"] = False
4580*da0073e9SAndroid Build Coastguard Worker                opt = optim_cls(params_control, **kwargs)
4581*da0073e9SAndroid Build Coastguard Worker
4582*da0073e9SAndroid Build Coastguard Worker                for i in range(steps_warmup + steps_train):
4583*da0073e9SAndroid Build Coastguard Worker                    for j, p in enumerate(params_control):
4584*da0073e9SAndroid Build Coastguard Worker                        p.grad = grads_control[i][j]
4585*da0073e9SAndroid Build Coastguard Worker                    scaler_for_control.step(opt)
4586*da0073e9SAndroid Build Coastguard Worker                    scaler_for_control.update()
4587*da0073e9SAndroid Build Coastguard Worker
4588*da0073e9SAndroid Build Coastguard Worker                # capturable=True
4589*da0073e9SAndroid Build Coastguard Worker                if optim_info.has_capturable_arg:
4590*da0073e9SAndroid Build Coastguard Worker                    kwargs["capturable"] = True
4591*da0073e9SAndroid Build Coastguard Worker                opt = optim_cls(params_graphed, **kwargs)
4592*da0073e9SAndroid Build Coastguard Worker
4593*da0073e9SAndroid Build Coastguard Worker                for i in range(steps_warmup):
4594*da0073e9SAndroid Build Coastguard Worker                    for j, p in enumerate(params_graphed):
4595*da0073e9SAndroid Build Coastguard Worker                        p.grad = grads_graphed[i][j]
4596*da0073e9SAndroid Build Coastguard Worker                    scaler_for_graphed.step(opt)
4597*da0073e9SAndroid Build Coastguard Worker                    scaler_for_graphed.update()
4598*da0073e9SAndroid Build Coastguard Worker
4599*da0073e9SAndroid Build Coastguard Worker                if actually_do_graphs:
4600*da0073e9SAndroid Build Coastguard Worker                    g = torch.cuda.CUDAGraph()
4601*da0073e9SAndroid Build Coastguard Worker                    with torch.cuda.graph(g):
4602*da0073e9SAndroid Build Coastguard Worker                        scaler_for_graphed.step(opt)
4603*da0073e9SAndroid Build Coastguard Worker                        scaler_for_graphed.update()
4604*da0073e9SAndroid Build Coastguard Worker
4605*da0073e9SAndroid Build Coastguard Worker                for i in range(steps_train):
4606*da0073e9SAndroid Build Coastguard Worker                    if actually_do_graphs:
4607*da0073e9SAndroid Build Coastguard Worker                        for j, p in enumerate(params_graphed):
4608*da0073e9SAndroid Build Coastguard Worker                            p.grad.copy_(grads_graphed[i + steps_warmup][j])
4609*da0073e9SAndroid Build Coastguard Worker                        g.replay()
4610*da0073e9SAndroid Build Coastguard Worker                    else:
4611*da0073e9SAndroid Build Coastguard Worker                        # Passing capturable=True to the constructor and running without graphs should still be
4612*da0073e9SAndroid Build Coastguard Worker                        # numerically correct, even if it's not ideal for performance.
4613*da0073e9SAndroid Build Coastguard Worker                        for j, p in enumerate(params_graphed):
4614*da0073e9SAndroid Build Coastguard Worker                            p.grad = grads_graphed[i + steps_warmup][j]
4615*da0073e9SAndroid Build Coastguard Worker                        scaler_for_graphed.step(opt)
4616*da0073e9SAndroid Build Coastguard Worker                        scaler_for_graphed.update()
4617*da0073e9SAndroid Build Coastguard Worker
4618*da0073e9SAndroid Build Coastguard Worker                for p_control, p_graphed in zip(params_control, params_graphed):
4619*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(p_control, p_graphed)
4620*da0073e9SAndroid Build Coastguard Worker
4621*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
4622*da0073e9SAndroid Build Coastguard Worker    @optims(
4623*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "fused" in optim.supported_impls],
4624*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32],
4625*da0073e9SAndroid Build Coastguard Worker    )
4626*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
4627*da0073e9SAndroid Build Coastguard Worker        device = device.split(":")[0]
4628*da0073e9SAndroid Build Coastguard Worker        if device not in optim_info.supports_fused_on:
4629*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
4630*da0073e9SAndroid Build Coastguard Worker                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
4631*da0073e9SAndroid Build Coastguard Worker            )
4632*da0073e9SAndroid Build Coastguard Worker        optim_inputs = optim_info.optim_inputs_func(device=device)
4633*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
4634*da0073e9SAndroid Build Coastguard Worker        for optim_input in optim_inputs:
4635*da0073e9SAndroid Build Coastguard Worker            for _separate_unscale in (True, False):
4636*da0073e9SAndroid Build Coastguard Worker                kwargs = optim_input.kwargs
4637*da0073e9SAndroid Build Coastguard Worker                kwargs["fused"] = True
4638*da0073e9SAndroid Build Coastguard Worker                torch.manual_seed(20)
4639*da0073e9SAndroid Build Coastguard Worker                (
4640*da0073e9SAndroid Build Coastguard Worker                    mod_control,
4641*da0073e9SAndroid Build Coastguard Worker                    mod_scaling,
4642*da0073e9SAndroid Build Coastguard Worker                    opt_control,
4643*da0073e9SAndroid Build Coastguard Worker                    opt_scaling,
4644*da0073e9SAndroid Build Coastguard Worker                    data,
4645*da0073e9SAndroid Build Coastguard Worker                    loss_fn,
4646*da0073e9SAndroid Build Coastguard Worker                    _,
4647*da0073e9SAndroid Build Coastguard Worker                ) = _create_scaling_case(
4648*da0073e9SAndroid Build Coastguard Worker                    optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device
4649*da0073e9SAndroid Build Coastguard Worker                )
4650*da0073e9SAndroid Build Coastguard Worker                optimizer_kwargs = deepcopy(kwargs)
4651*da0073e9SAndroid Build Coastguard Worker                optimizer_kwargs["fused"] = False
4652*da0073e9SAndroid Build Coastguard Worker                if "lr" not in kwargs:
4653*da0073e9SAndroid Build Coastguard Worker                    # _create_scaling_case sets lr = 1.0 if optimizer_kwargs does not set it
4654*da0073e9SAndroid Build Coastguard Worker                    optimizer_kwargs["lr"] = 1.0
4655*da0073e9SAndroid Build Coastguard Worker                opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs)
4656*da0073e9SAndroid Build Coastguard Worker                scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0)
4657*da0073e9SAndroid Build Coastguard Worker                scaler_control = torch.amp.GradScaler(device, init_scale=128.0)
4658*da0073e9SAndroid Build Coastguard Worker                tracker = TensorTracker()
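                # Run the fused optimizer (opt_scaling) and the non-fused control
                # (opt_control) step by step on identical data, checking after every
                # step that losses, grads, params, and optimizer state stay in sync.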
4659*da0073e9SAndroid Build Coastguard Worker                for input, target in data:
4660*da0073e9SAndroid Build Coastguard Worker                    opt_control.zero_grad()
4661*da0073e9SAndroid Build Coastguard Worker                    with torch.autocast(device_type=device, dtype=torch.half):
4662*da0073e9SAndroid Build Coastguard Worker                        output_control = mod_control(input)
4663*da0073e9SAndroid Build Coastguard Worker                        loss_control = loss_fn(output_control, target)
4664*da0073e9SAndroid Build Coastguard Worker                    scaler_control.scale(loss_control).backward()
4665*da0073e9SAndroid Build Coastguard Worker                    scaler_control.step(opt_control)
4666*da0073e9SAndroid Build Coastguard Worker                    scaler_control.update()
4667*da0073e9SAndroid Build Coastguard Worker
4668*da0073e9SAndroid Build Coastguard Worker                    opt_scaling.zero_grad()
4669*da0073e9SAndroid Build Coastguard Worker                    with torch.autocast(device_type=device, dtype=torch.half):
4670*da0073e9SAndroid Build Coastguard Worker                        output_scaling = mod_scaling(input)
4671*da0073e9SAndroid Build Coastguard Worker                        loss_scaling = loss_fn(output_scaling, target)
4672*da0073e9SAndroid Build Coastguard Worker                    scaler_scaling.scale(loss_scaling).backward()
4673*da0073e9SAndroid Build Coastguard Worker                    if _separate_unscale:
4674*da0073e9SAndroid Build Coastguard Worker                        scaler_scaling.unscale_(opt_scaling)
4675*da0073e9SAndroid Build Coastguard Worker                    scaler_scaling.step(opt_scaling)
4676*da0073e9SAndroid Build Coastguard Worker                    scaler_scaling.update()
4677*da0073e9SAndroid Build Coastguard Worker
4678*da0073e9SAndroid Build Coastguard Worker                    tracker.add(loss_control)
4679*da0073e9SAndroid Build Coastguard Worker                    tracker.pop_check_set(loss_scaling, self)
4680*da0073e9SAndroid Build Coastguard Worker                    for param_control, param_scaling in zip(
4681*da0073e9SAndroid Build Coastguard Worker                        mod_control.parameters(), mod_scaling.parameters()
4682*da0073e9SAndroid Build Coastguard Worker                    ):
4683*da0073e9SAndroid Build Coastguard Worker                        tracker.add(param_control.grad)
4684*da0073e9SAndroid Build Coastguard Worker                        tracker.pop_check_set(param_scaling.grad, self)
4685*da0073e9SAndroid Build Coastguard Worker                        tracker.add(param_control)
4686*da0073e9SAndroid Build Coastguard Worker                        tracker.pop_check_set(param_scaling, self)
4687*da0073e9SAndroid Build Coastguard Worker
4688*da0073e9SAndroid Build Coastguard Worker                        state_control, state_scaling = (
4689*da0073e9SAndroid Build Coastguard Worker                            opt_control.state[param_control],
4690*da0073e9SAndroid Build Coastguard Worker                            opt_scaling.state[param_scaling],
4691*da0073e9SAndroid Build Coastguard Worker                        )
4692*da0073e9SAndroid Build Coastguard Worker
4693*da0073e9SAndroid Build Coastguard Worker                        for k in state_control:
4694*da0073e9SAndroid Build Coastguard Worker                            actual = state_scaling[k]
4695*da0073e9SAndroid Build Coastguard Worker                            if k == "step":
4696*da0073e9SAndroid Build Coastguard Worker                                actual = actual.squeeze()
4697*da0073e9SAndroid Build Coastguard Worker                            tracker.add(state_control[k])
4698*da0073e9SAndroid Build Coastguard Worker                            tracker.pop_check_set(actual, self)
4699*da0073e9SAndroid Build Coastguard Worker
4700*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4701*da0073e9SAndroid Build Coastguard Worker    @parametrize("in_place_unscale", [False, True])
4702*da0073e9SAndroid Build Coastguard Worker    @optims(
4703*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "cuda" in optim.supports_fused_on],
4704*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32],
4705*da0073e9SAndroid Build Coastguard Worker    )
4706*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaler_with_preset_grad_scale(
4707*da0073e9SAndroid Build Coastguard Worker        self, device, dtype, optim_info, in_place_unscale
4708*da0073e9SAndroid Build Coastguard Worker    ):
4709*da0073e9SAndroid Build Coastguard Worker        weight = torch.ones((5, 5), device="cuda", requires_grad=True)
4710*da0073e9SAndroid Build Coastguard Worker        weight.grad = torch.full_like(weight, fill_value=15)
4711*da0073e9SAndroid Build Coastguard Worker        opt = optim_info.optim_cls([weight], lr=0.1, fused=True)
4712*da0073e9SAndroid Build Coastguard Worker        scaler = torch.amp.GradScaler(init_scale=5)
4713*da0073e9SAndroid Build Coastguard Worker
4714*da0073e9SAndroid Build Coastguard Worker        # simulate scaling a loss
4715*da0073e9SAndroid Build Coastguard Worker        scaler.scale(torch.ones(5))
4716*da0073e9SAndroid Build Coastguard Worker
4717*da0073e9SAndroid Build Coastguard Worker        if in_place_unscale:
4718*da0073e9SAndroid Build Coastguard Worker            scaler.unscale_(opt)
4719*da0073e9SAndroid Build Coastguard Worker            # the gradient should have been divided in-place
4720*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(weight.grad, torch.full_like(weight, fill_value=3))
4721*da0073e9SAndroid Build Coastguard Worker
4722*da0073e9SAndroid Build Coastguard Worker        # the user sets a `grad_scale` value which should be fused with the optimizer step
4723*da0073e9SAndroid Build Coastguard Worker        opt.grad_scale = torch.Tensor([3]).cuda()
4724*da0073e9SAndroid Build Coastguard Worker        scaler.step(opt)
4725*da0073e9SAndroid Build Coastguard Worker
4726*da0073e9SAndroid Build Coastguard Worker        # check that the user's grad_scale was respected (i.e. the gradient was divided by 5 * 3)
4727*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(weight.grad, torch.full_like(weight, fill_value=1))
4728*da0073e9SAndroid Build Coastguard Worker
4729*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4730*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
4731*da0073e9SAndroid Build Coastguard Worker        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
4732*da0073e9SAndroid Build Coastguard Worker    )
4733*da0073e9SAndroid Build Coastguard Worker    @parametrize("foreach, fused", [(False, False), (True, False), (False, True)])
4734*da0073e9SAndroid Build Coastguard Worker    @optims(
4735*da0073e9SAndroid Build Coastguard Worker        [
4736*da0073e9SAndroid Build Coastguard Worker            optim
4737*da0073e9SAndroid Build Coastguard Worker            for optim in optim_db
4738*da0073e9SAndroid Build Coastguard Worker            if "foreach" in optim.supported_impls and "cuda" in optim.supports_fused_on
4739*da0073e9SAndroid Build Coastguard Worker        ],
4740*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32],
4741*da0073e9SAndroid Build Coastguard Worker    )
4742*da0073e9SAndroid Build Coastguard Worker    def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused):
4743*da0073e9SAndroid Build Coastguard Worker        torch.cuda.empty_cache()
4744*da0073e9SAndroid Build Coastguard Worker
4745*da0073e9SAndroid Build Coastguard Worker        scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0)
4746*da0073e9SAndroid Build Coastguard Worker        g = torch.cuda.CUDAGraph()
4747*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
4748*da0073e9SAndroid Build Coastguard Worker
4749*da0073e9SAndroid Build Coastguard Worker        weight = torch.ones((100,), device="cuda", requires_grad=True)
4750*da0073e9SAndroid Build Coastguard Worker        opt = optim_info.optim_cls([weight], lr=0.1, foreach=foreach, fused=fused)
4751*da0073e9SAndroid Build Coastguard Worker        static_input = torch.ones_like(weight)
4752*da0073e9SAndroid Build Coastguard Worker        static_grad = torch.ones_like(weight)
4753*da0073e9SAndroid Build Coastguard Worker
4754*da0073e9SAndroid Build Coastguard Worker        # warmup
4755*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
4756*da0073e9SAndroid Build Coastguard Worker        s.wait_stream(torch.cuda.current_stream())
4757*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
4758*da0073e9SAndroid Build Coastguard Worker            loss = (weight.half() * static_input).sum()
4759*da0073e9SAndroid Build Coastguard Worker            scaler.scale(loss).backward()
4760*da0073e9SAndroid Build Coastguard Worker        torch.cuda.current_stream().wait_stream(s)
4761*da0073e9SAndroid Build Coastguard Worker
4762*da0073e9SAndroid Build Coastguard Worker        opt.zero_grad(set_to_none=True)
4763*da0073e9SAndroid Build Coastguard Worker
4764*da0073e9SAndroid Build Coastguard Worker        # capture
4765*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
4766*da0073e9SAndroid Build Coastguard Worker            g.capture_begin()
4767*da0073e9SAndroid Build Coastguard Worker            loss = (weight.half() * static_input).sum()
4768*da0073e9SAndroid Build Coastguard Worker            scaler.scale(loss).backward()
4769*da0073e9SAndroid Build Coastguard Worker            g.capture_end()
4770*da0073e9SAndroid Build Coastguard Worker
4771*da0073e9SAndroid Build Coastguard Worker        input_vals = [5, 20000, 5, 40000]
4772*da0073e9SAndroid Build Coastguard Worker        # If the scale gets updated properly, these are the scale, growth tracker,
4773*da0073e9SAndroid Build Coastguard Worker        # and grad values we expect (see the illustrative sketch after this test).
4774*da0073e9SAndroid Build Coastguard Worker        expected_scales = [4, 2, 2, 1]
4775*da0073e9SAndroid Build Coastguard Worker        expected_growth_trackers = [1, 0, 1, 0]
4776*da0073e9SAndroid Build Coastguard Worker        expected_grad_vals = [5 * 4, float("inf"), 5 * 2, float("inf")]
4777*da0073e9SAndroid Build Coastguard Worker
4778*da0073e9SAndroid Build Coastguard Worker        for data, scale, growth_tracker, grad_val in zip(
4779*da0073e9SAndroid Build Coastguard Worker            input_vals, expected_scales, expected_growth_trackers, expected_grad_vals
4780*da0073e9SAndroid Build Coastguard Worker        ):
4781*da0073e9SAndroid Build Coastguard Worker            static_input.fill_(data)
4782*da0073e9SAndroid Build Coastguard Worker            g.replay()
4783*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(weight.grad, torch.full_like(weight.grad, grad_val))
4784*da0073e9SAndroid Build Coastguard Worker            scaler.step(opt)
4785*da0073e9SAndroid Build Coastguard Worker            scaler.update()
4786*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scaler._scale, scale)
4787*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scaler._growth_tracker, growth_tracker)
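
    # Illustrative sketch (not part of the upstream test suite; helper name and
    # constants are hypothetical): the expected scale/growth-tracker values
    # above follow GradScaler's update rule -- an inf/nan gradient makes step()
    # skip the optimizer, halves the scale (default backoff_factor=0.5), and
    # resets the growth tracker, while a clean step leaves the scale alone
    # until growth_interval steps have passed.  A minimal eager-mode version,
    # assuming a CUDA device and default scaler settings:
    def _sketch_gradscaler_update_dynamics(self):
        scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0)
        weight = torch.ones((4,), device="cuda", requires_grad=True)
        opt = torch.optim.SGD([weight], lr=0.1)
        # The second multiplier overflows fp16 and produces inf gradients.
        for data in (5.0, 65504.0 * 65504.0):
            opt.zero_grad(set_to_none=True)
            loss = (weight.half() * data).sum()
            scaler.scale(loss).backward()
            scaler.step(opt)  # skipped when the scaled grads contain inf/nan
            scaler.update()
        # The overflowing iteration should have halved the scale: 4.0 -> 2.0.
        self.assertEqual(scaler.get_scale(), 2.0)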
4788*da0073e9SAndroid Build Coastguard Worker
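    # A second illustrative sketch: the warmup-then-capture pattern used in
    # test_graph_grad_scaling above, rewritten with the torch.cuda.graph
    # context manager instead of explicit capture_begin()/capture_end() calls.
    # Assumes a device/toolkit combination that supports CUDA graphs; the
    # helper name is hypothetical.
    def _sketch_cuda_graph_capture_replay(self):
        static_input = torch.ones(8, device="cuda")
        static_output = torch.zeros(8, device="cuda")

        # Warm up on a side stream so lazily-allocated workspaces exist before
        # capture begins.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            static_output.copy_(static_input * 2)
        torch.cuda.current_stream().wait_stream(s)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_output.copy_(static_input * 2)

        # Replays rerun the captured kernels; only the contents of the static
        # tensors change between replays.
        static_input.fill_(3.0)
        g.replay()
        torch.cuda.synchronize()
        self.assertEqual(static_output, torch.full((8,), 6.0, device="cuda"))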
4789*da0073e9SAndroid Build Coastguard Worker
4790*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
4791*da0073e9SAndroid Build Coastguard Workerclass TestGDS(TestCase):
4792*da0073e9SAndroid Build Coastguard Worker    def _get_tmp_dir_fs_type(self):
4793*da0073e9SAndroid Build Coastguard Worker        my_path = os.path.realpath("/tmp")
4794*da0073e9SAndroid Build Coastguard Worker        root_type = ""
4795*da0073e9SAndroid Build Coastguard Worker        for part in psutil.disk_partitions():
4796*da0073e9SAndroid Build Coastguard Worker            if part.mountpoint == "/":
4797*da0073e9SAndroid Build Coastguard Worker                root_type = part.fstype
4798*da0073e9SAndroid Build Coastguard Worker                continue
4799*da0073e9SAndroid Build Coastguard Worker            if part.mountpoint == my_path:
4800*da0073e9SAndroid Build Coastguard Worker                return part.fstype
4801*da0073e9SAndroid Build Coastguard Worker        return root_type
4802*da0073e9SAndroid Build Coastguard Worker
4803*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
4804*da0073e9SAndroid Build Coastguard Worker    def test_gds_read_write_tensors(self):
4805*da0073e9SAndroid Build Coastguard Worker        if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
4806*da0073e9SAndroid Build Coastguard Worker            self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
4807*da0073e9SAndroid Build Coastguard Worker        src1 = torch.randn(1024, device="cuda")
4808*da0073e9SAndroid Build Coastguard Worker        src2 = torch.randn(2, 1024, device="cuda")
4809*da0073e9SAndroid Build Coastguard Worker        torch.cuda.gds._gds_register_buffer(src1.untyped_storage())
4810*da0073e9SAndroid Build Coastguard Worker        torch.cuda.gds._gds_register_buffer(src2.untyped_storage())
4811*da0073e9SAndroid Build Coastguard Worker        dest1 = torch.empty(1024, device="cuda")
4812*da0073e9SAndroid Build Coastguard Worker        dest2 = torch.empty(2, 1024, device="cuda")
4813*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as f:
4814*da0073e9SAndroid Build Coastguard Worker            file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
4815*da0073e9SAndroid Build Coastguard Worker            file.save_storage(src1.untyped_storage(), offset=0)
4816*da0073e9SAndroid Build Coastguard Worker            file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
4817*da0073e9SAndroid Build Coastguard Worker            file.load_storage(dest1.untyped_storage(), offset=0)
4818*da0073e9SAndroid Build Coastguard Worker            file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
4819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(src1, dest1)
4820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(src2, dest2)
4821*da0073e9SAndroid Build Coastguard Worker        torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage())
4822*da0073e9SAndroid Build Coastguard Worker        torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())
4823*da0073e9SAndroid Build Coastguard Worker
4824*da0073e9SAndroid Build Coastguard Worker
4825*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
4826*da0073e9SAndroid Build Coastguard Workerclass TestCudaAutocast(TestAutocast):
4827*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
4828*da0073e9SAndroid Build Coastguard Worker        super().setUp()
4829*da0073e9SAndroid Build Coastguard Worker        self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))
4830*da0073e9SAndroid Build Coastguard Worker
4831*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
4832*da0073e9SAndroid Build Coastguard Worker        del self.autocast_lists
4833*da0073e9SAndroid Build Coastguard Worker        super().tearDown()
4834*da0073e9SAndroid Build Coastguard Worker
4835*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4836*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_fp16(self):
4837*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4838*da0073e9SAndroid Build Coastguard Worker            for op_with_args in self.autocast_lists.torch_fp16:
4839*da0073e9SAndroid Build Coastguard Worker                skip_test = False
4840*da0073e9SAndroid Build Coastguard Worker                op, args = op_with_args[0], op_with_args[1]
4841*da0073e9SAndroid Build Coastguard Worker                if len(op_with_args) == 3:
4842*da0073e9SAndroid Build Coastguard Worker                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
4843*da0073e9SAndroid Build Coastguard Worker                if not skip_test:
4844*da0073e9SAndroid Build Coastguard Worker                    self._run_autocast_outofplace(
4845*da0073e9SAndroid Build Coastguard Worker                        op, args, torch.float16, device="cuda", amp_dtype=torch.float16
4846*da0073e9SAndroid Build Coastguard Worker                    )
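
    # Compact illustrative sketch of what the out-of-place fp16 checks above
    # boil down to for a single op (assuming torch.mm is on autocast's
    # lower-precision cast list, as the torch_fp16 list implies); the helper
    # name is hypothetical.
    def _sketch_autocast_fp16_mm(self):
        a = torch.randn(8, 8, device="cuda")
        b = torch.randn(8, 8, device="cuda")
        with torch.autocast("cuda"):
            out = torch.mm(a, b)
        self.assertTrue(out.dtype is torch.float16)
        # Outside the autocast region the same call stays in fp32.
        self.assertTrue(torch.mm(a, b).dtype is torch.float32)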
4847*da0073e9SAndroid Build Coastguard Worker
4848*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4849*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_bf16(self):
4850*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4851*da0073e9SAndroid Build Coastguard Worker            for op_with_args in self.autocast_lists.torch_fp16:
4852*da0073e9SAndroid Build Coastguard Worker                skip_test = False
4853*da0073e9SAndroid Build Coastguard Worker                op, args = op_with_args[0], op_with_args[1]
4854*da0073e9SAndroid Build Coastguard Worker                if len(op_with_args) == 3:
4855*da0073e9SAndroid Build Coastguard Worker                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
4856*da0073e9SAndroid Build Coastguard Worker                should_error_from_cudnn = "cudnn" in op and (
4857*da0073e9SAndroid Build Coastguard Worker                    "TORCH_CUDNN_V8_API_DISABLED" in os.environ
4858*da0073e9SAndroid Build Coastguard Worker                    and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
4859*da0073e9SAndroid Build Coastguard Worker                    or torch.cuda.get_device_capability() < (8, 0)
4860*da0073e9SAndroid Build Coastguard Worker                )
4861*da0073e9SAndroid Build Coastguard Worker                should_error_from_not_implemented = should_error_from_cudnn
4862*da0073e9SAndroid Build Coastguard Worker                if not skip_test:
4863*da0073e9SAndroid Build Coastguard Worker                    if should_error_from_not_implemented:
4864*da0073e9SAndroid Build Coastguard Worker                        with self.assertRaises(
4865*da0073e9SAndroid Build Coastguard Worker                            RuntimeError,
4866*da0073e9SAndroid Build Coastguard Worker                            msg=str(op) + " should not be supported for bfloat16!",
4867*da0073e9SAndroid Build Coastguard Worker                        ):
4868*da0073e9SAndroid Build Coastguard Worker                            self._run_autocast_outofplace(
4869*da0073e9SAndroid Build Coastguard Worker                                op, args, torch.bfloat16, device="cuda"
4870*da0073e9SAndroid Build Coastguard Worker                            )
4871*da0073e9SAndroid Build Coastguard Worker                    else:
4872*da0073e9SAndroid Build Coastguard Worker                        if torch.cuda.is_bf16_supported():
4873*da0073e9SAndroid Build Coastguard Worker                            self._run_autocast_outofplace(
4874*da0073e9SAndroid Build Coastguard Worker                                op, args, torch.bfloat16, device="cuda"
4875*da0073e9SAndroid Build Coastguard Worker                            )
4876*da0073e9SAndroid Build Coastguard Worker                        else:
4877*da0073e9SAndroid Build Coastguard Worker                            with self.assertRaisesRegex(
4878*da0073e9SAndroid Build Coastguard Worker                                RuntimeError, "Device does not support bfloat16"
4879*da0073e9SAndroid Build Coastguard Worker                            ):
4880*da0073e9SAndroid Build Coastguard Worker                                self._run_autocast_outofplace(
4881*da0073e9SAndroid Build Coastguard Worker                                    op, args, torch.bfloat16, device="cuda"
4882*da0073e9SAndroid Build Coastguard Worker                                )
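
    # Illustrative sketch of the bfloat16 gating exercised above: on devices
    # without bf16 support, entering a bf16 autocast region raises rather than
    # silently running in another dtype.  Assumes torch.mm is on the
    # lower-precision cast list; the helper name is hypothetical.
    def _sketch_autocast_bf16_gating(self):
        a = torch.randn(8, 8, device="cuda")
        b = torch.randn(8, 8, device="cuda")
        if torch.cuda.is_bf16_supported():
            with torch.autocast("cuda", dtype=torch.bfloat16):
                self.assertTrue(torch.mm(a, b).dtype is torch.bfloat16)
        else:
            with self.assertRaisesRegex(
                RuntimeError, "Device does not support bfloat16"
            ):
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    torch.mm(a, b)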
4883*da0073e9SAndroid Build Coastguard Worker
4884*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4885*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_fp32(self):
4886*da0073e9SAndroid Build Coastguard Worker        for op_with_args in self.autocast_lists.torch_fp32:
4887*da0073e9SAndroid Build Coastguard Worker            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
4888*da0073e9SAndroid Build Coastguard Worker            self._run_autocast_outofplace(
4889*da0073e9SAndroid Build Coastguard Worker                op,
4890*da0073e9SAndroid Build Coastguard Worker                args,
4891*da0073e9SAndroid Build Coastguard Worker                torch.float32,
4892*da0073e9SAndroid Build Coastguard Worker                device="cuda",
4893*da0073e9SAndroid Build Coastguard Worker                add_kwargs=maybe_kwargs,
4894*da0073e9SAndroid Build Coastguard Worker                amp_dtype=torch.float16,
4895*da0073e9SAndroid Build Coastguard Worker            )
4896*da0073e9SAndroid Build Coastguard Worker
4897*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4898*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_need_autocast_promote(self):
4899*da0073e9SAndroid Build Coastguard Worker        for op, args in self.autocast_lists.torch_need_autocast_promote:
4900*da0073e9SAndroid Build Coastguard Worker            self._run_autocast_outofplace(
4901*da0073e9SAndroid Build Coastguard Worker                op, args, torch.float32, device="cuda", amp_dtype=torch.float16
4902*da0073e9SAndroid Build Coastguard Worker            )
4903*da0073e9SAndroid Build Coastguard Worker
4904*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4905*da0073e9SAndroid Build Coastguard Worker    def test_autocast_torch_expect_builtin_promote(self):
4906*da0073e9SAndroid Build Coastguard Worker        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
4907*da0073e9SAndroid Build Coastguard Worker            self._run_autocast_outofplace(
4908*da0073e9SAndroid Build Coastguard Worker                op,
4909*da0073e9SAndroid Build Coastguard Worker                args,
4910*da0073e9SAndroid Build Coastguard Worker                torch.float32,
4911*da0073e9SAndroid Build Coastguard Worker                device="cuda",
4912*da0073e9SAndroid Build Coastguard Worker                out_type=out_type,
4913*da0073e9SAndroid Build Coastguard Worker                amp_dtype=torch.float16,
4914*da0073e9SAndroid Build Coastguard Worker            )
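
    # Illustrative sketch of the "builtin promote" policy checked above,
    # assuming torch.cat is one of the promote-policy ops (the op list above
    # drives the real coverage): mixed fp16/fp32 inputs are promoted to the
    # widest floating dtype instead of being forced to fp16.  Hypothetical
    # helper, not run by CI.
    def _sketch_autocast_builtin_promote_cat(self):
        a = torch.randn(4, device="cuda", dtype=torch.float16)
        b = torch.randn(4, device="cuda", dtype=torch.float32)
        with torch.autocast("cuda"):
            out = torch.cat((a, b))
        self.assertTrue(out.dtype is torch.float32)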
4915*da0073e9SAndroid Build Coastguard Worker
4916*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4917*da0073e9SAndroid Build Coastguard Worker    def test_autocast_nn_fp16(self):
4918*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4919*da0073e9SAndroid Build Coastguard Worker            for op, args in self.autocast_lists.nn_fp16:
4920*da0073e9SAndroid Build Coastguard Worker                self._run_autocast_outofplace(
4921*da0073e9SAndroid Build Coastguard Worker                    op,
4922*da0073e9SAndroid Build Coastguard Worker                    args,
4923*da0073e9SAndroid Build Coastguard Worker                    torch.float16,
4924*da0073e9SAndroid Build Coastguard Worker                    device="cuda",
4925*da0073e9SAndroid Build Coastguard Worker                    module=torch._C._nn,
4926*da0073e9SAndroid Build Coastguard Worker                    amp_dtype=torch.float16,
4927*da0073e9SAndroid Build Coastguard Worker                )
4928*da0073e9SAndroid Build Coastguard Worker
4929*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4930*da0073e9SAndroid Build Coastguard Worker    def test_autocast_nn_bf16(self):
4931*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4932*da0073e9SAndroid Build Coastguard Worker            for op, args in self.autocast_lists.nn_fp16:
4933*da0073e9SAndroid Build Coastguard Worker                if torch.cuda.is_bf16_supported():
4934*da0073e9SAndroid Build Coastguard Worker                    self._run_autocast_outofplace(
4935*da0073e9SAndroid Build Coastguard Worker                        op, args, torch.bfloat16, device="cuda", module=torch._C._nn
4936*da0073e9SAndroid Build Coastguard Worker                    )
4937*da0073e9SAndroid Build Coastguard Worker                else:
4938*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
4939*da0073e9SAndroid Build Coastguard Worker                        RuntimeError, "Device does not support bfloat16"
4940*da0073e9SAndroid Build Coastguard Worker                    ):
4941*da0073e9SAndroid Build Coastguard Worker                        self._run_autocast_outofplace(
4942*da0073e9SAndroid Build Coastguard Worker                            op, args, torch.bfloat16, device="cuda", module=torch._C._nn
4943*da0073e9SAndroid Build Coastguard Worker                        )
4944*da0073e9SAndroid Build Coastguard Worker
4945*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4946*da0073e9SAndroid Build Coastguard Worker    def test_autocast_nn_fp32(self):
4947*da0073e9SAndroid Build Coastguard Worker        for op, args in self.autocast_lists.nn_fp32:
4948*da0073e9SAndroid Build Coastguard Worker            self._run_autocast_outofplace(
4949*da0073e9SAndroid Build Coastguard Worker                op,
4950*da0073e9SAndroid Build Coastguard Worker                args,
4951*da0073e9SAndroid Build Coastguard Worker                torch.float32,
4952*da0073e9SAndroid Build Coastguard Worker                device="cuda",
4953*da0073e9SAndroid Build Coastguard Worker                module=torch._C._nn,
4954*da0073e9SAndroid Build Coastguard Worker                amp_dtype=torch.float16,
4955*da0073e9SAndroid Build Coastguard Worker            )
4956*da0073e9SAndroid Build Coastguard Worker
4957*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4958*da0073e9SAndroid Build Coastguard Worker    def test_autocast_linalg_fp16(self):
4959*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4960*da0073e9SAndroid Build Coastguard Worker            for op, args in self.autocast_lists.linalg_fp16:
4961*da0073e9SAndroid Build Coastguard Worker                self._run_autocast_outofplace(
4962*da0073e9SAndroid Build Coastguard Worker                    op,
4963*da0073e9SAndroid Build Coastguard Worker                    args,
4964*da0073e9SAndroid Build Coastguard Worker                    torch.float16,
4965*da0073e9SAndroid Build Coastguard Worker                    device="cuda",
4966*da0073e9SAndroid Build Coastguard Worker                    module=torch._C._linalg,
4967*da0073e9SAndroid Build Coastguard Worker                    amp_dtype=torch.float16,
4968*da0073e9SAndroid Build Coastguard Worker                )
4969*da0073e9SAndroid Build Coastguard Worker
4970*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4971*da0073e9SAndroid Build Coastguard Worker    def test_autocast_methods_fp16(self):
4972*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4973*da0073e9SAndroid Build Coastguard Worker            for op, args in self.autocast_lists.methods_fp16:
4974*da0073e9SAndroid Build Coastguard Worker                self._run_autocast_outofplace(
4975*da0073e9SAndroid Build Coastguard Worker                    op,
4976*da0073e9SAndroid Build Coastguard Worker                    args,
4977*da0073e9SAndroid Build Coastguard Worker                    torch.float16,
4978*da0073e9SAndroid Build Coastguard Worker                    device="cuda",
4979*da0073e9SAndroid Build Coastguard Worker                    module=None,
4980*da0073e9SAndroid Build Coastguard Worker                    amp_dtype=torch.float16,
4981*da0073e9SAndroid Build Coastguard Worker                )
4982*da0073e9SAndroid Build Coastguard Worker
4983*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4984*da0073e9SAndroid Build Coastguard Worker    def test_autocast_methods_fp32(self):
4985*da0073e9SAndroid Build Coastguard Worker        for op, args in self.autocast_lists.methods_fp32:
4986*da0073e9SAndroid Build Coastguard Worker            self._run_autocast_outofplace(
4987*da0073e9SAndroid Build Coastguard Worker                op,
4988*da0073e9SAndroid Build Coastguard Worker                args,
4989*da0073e9SAndroid Build Coastguard Worker                torch.float32,
4990*da0073e9SAndroid Build Coastguard Worker                device="cuda",
4991*da0073e9SAndroid Build Coastguard Worker                module=None,
4992*da0073e9SAndroid Build Coastguard Worker                amp_dtype=torch.float16,
4993*da0073e9SAndroid Build Coastguard Worker            )
4994*da0073e9SAndroid Build Coastguard Worker
4995*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4996*da0073e9SAndroid Build Coastguard Worker    def test_autocast_methods_expect_builtin_promote(self):
4997*da0073e9SAndroid Build Coastguard Worker        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
4998*da0073e9SAndroid Build Coastguard Worker            self._run_autocast_outofplace(
4999*da0073e9SAndroid Build Coastguard Worker                op,
5000*da0073e9SAndroid Build Coastguard Worker                args,
5001*da0073e9SAndroid Build Coastguard Worker                torch.float32,
5002*da0073e9SAndroid Build Coastguard Worker                device="cuda",
5003*da0073e9SAndroid Build Coastguard Worker                module=None,
5004*da0073e9SAndroid Build Coastguard Worker                out_type=out_type,
5005*da0073e9SAndroid Build Coastguard Worker                amp_dtype=torch.float16,
5006*da0073e9SAndroid Build Coastguard Worker            )
5007*da0073e9SAndroid Build Coastguard Worker
5008*da0073e9SAndroid Build Coastguard Worker    def test_autocast_banned(self):
5009*da0073e9SAndroid Build Coastguard Worker        with torch.autocast("cuda"):
5010*da0073e9SAndroid Build Coastguard Worker            for op, args, module in self.autocast_lists.banned:
5011*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(RuntimeError):
5012*da0073e9SAndroid Build Coastguard Worker                    getattr(module, op)(*args)
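
    # Illustrative sketch with one concrete banned op (an assumption based on
    # the autocast documentation, which calls out binary_cross_entropy as
    # unsafe to autocast): the call raises inside the enabled region regardless
    # of input dtypes.  Hypothetical helper, not run by CI.
    def _sketch_autocast_banned_bce(self):
        pred = torch.rand(8, device="cuda")
        target = torch.rand(8, device="cuda")
        with torch.autocast("cuda"):
            with self.assertRaises(RuntimeError):
                torch.nn.functional.binary_cross_entropy(pred, target)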
5013*da0073e9SAndroid Build Coastguard Worker
5014*da0073e9SAndroid Build Coastguard Worker    def test_autocast_ignored_types(self):
5015*da0073e9SAndroid Build Coastguard Worker        with torch.autocast("cuda"):
5016*da0073e9SAndroid Build Coastguard Worker            for ignore_type in (torch.double, torch.int32):
5017*da0073e9SAndroid Build Coastguard Worker                a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
5018*da0073e9SAndroid Build Coastguard Worker                b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
5019*da0073e9SAndroid Build Coastguard Worker                c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")
5020*da0073e9SAndroid Build Coastguard Worker
5021*da0073e9SAndroid Build Coastguard Worker                # Tests if CastPolicy::fp16 ops ignore double and int
5022*da0073e9SAndroid Build Coastguard Worker                # Currently, no ops belonging to this policy support integer inputs.
5023*da0073e9SAndroid Build Coastguard Worker                if ignore_type is torch.double:
5024*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(RuntimeError):
5025*da0073e9SAndroid Build Coastguard Worker                        torch.mm(a_ignore, c_16)
5026*da0073e9SAndroid Build Coastguard Worker                    with torch.autocast("cuda", enabled=False):
5027*da0073e9SAndroid Build Coastguard Worker                        type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
5028*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
5029*da0073e9SAndroid Build Coastguard Worker                        torch.mm(a_ignore, b_ignore).dtype is type_no_autocast
5030*da0073e9SAndroid Build Coastguard Worker                    )
5031*da0073e9SAndroid Build Coastguard Worker
5032*da0073e9SAndroid Build Coastguard Worker                # Tests if CastPolicy::fp32 ops ignore double and int
5033*da0073e9SAndroid Build Coastguard Worker                with torch.autocast("cuda", enabled=False):
5034*da0073e9SAndroid Build Coastguard Worker                    type_no_autocast = torch.pow(a_ignore, 2.0).dtype
5035*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)
5036*da0073e9SAndroid Build Coastguard Worker
5037*da0073e9SAndroid Build Coastguard Worker                # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
5038*da0073e9SAndroid Build Coastguard Worker                with torch.autocast("cuda", enabled=False):
5039*da0073e9SAndroid Build Coastguard Worker                    type_no_autocast = torch.sum(a_ignore).dtype
5040*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)
5041*da0073e9SAndroid Build Coastguard Worker
5042*da0073e9SAndroid Build Coastguard Worker                # Tests if CastPolicy::fp32_append_dtype ops ignore double and int
5043*da0073e9SAndroid Build Coastguard Worker                # Currently, no ops belonging to this policy support integer inputs.
5044*da0073e9SAndroid Build Coastguard Worker                if ignore_type is torch.double:
5045*da0073e9SAndroid Build Coastguard Worker                    with torch.autocast("cuda", enabled=False):
5046*da0073e9SAndroid Build Coastguard Worker                        type_no_autocast = torch.norm(a_ignore).dtype
5047*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)
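
    # Compact restatement of the passthrough behavior verified above: inputs in
    # an ignored dtype (here fp64) are left untouched by autocast, so the op
    # runs and returns in fp64 exactly as it would with autocast disabled.
    # Hypothetical helper, not run by CI.
    def _sketch_autocast_fp64_passthrough(self):
        a = torch.ones(8, 8, dtype=torch.double, device="cuda:0")
        with torch.autocast("cuda"):
            self.assertTrue(torch.mm(a, a).dtype is torch.double)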
5048*da0073e9SAndroid Build Coastguard Worker
5049*da0073e9SAndroid Build Coastguard Worker    def test_autocast_custom_enabled(self):
5050*da0073e9SAndroid Build Coastguard Worker        class MyMM(torch.autograd.Function):
5051*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5052*da0073e9SAndroid Build Coastguard Worker            @torch.amp.custom_fwd(device_type="cuda")
5053*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b):
5054*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(a.dtype is torch.float32)
5055*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(b.dtype is torch.float32)
5056*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_autocast_enabled())
5057*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(a, b)
5058*da0073e9SAndroid Build Coastguard Worker                return a.mm(b)
5059*da0073e9SAndroid Build Coastguard Worker
5060*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5061*da0073e9SAndroid Build Coastguard Worker            @torch.amp.custom_bwd(device_type="cuda")
5062*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
5063*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_autocast_enabled())
5064*da0073e9SAndroid Build Coastguard Worker                a, b = ctx.saved_tensors
5065*da0073e9SAndroid Build Coastguard Worker                a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
5066*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype)
5067*da0073e9SAndroid Build Coastguard Worker                return a_grad, b_grad
5068*da0073e9SAndroid Build Coastguard Worker
5069*da0073e9SAndroid Build Coastguard Worker        mymm = MyMM.apply
5070*da0073e9SAndroid Build Coastguard Worker
5071*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
5072*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
5073*da0073e9SAndroid Build Coastguard Worker
5074*da0073e9SAndroid Build Coastguard Worker        dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
5075*da0073e9SAndroid Build Coastguard Worker        for dtype in dtypes:
5076*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.amp.autocast(dtype=dtype):
5077*da0073e9SAndroid Build Coastguard Worker                output = mymm(x, y)
5078*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(output.dtype is dtype)
5079*da0073e9SAndroid Build Coastguard Worker                loss = output.sum()
5080*da0073e9SAndroid Build Coastguard Worker            loss.backward()
5081*da0073e9SAndroid Build Coastguard Worker
5082*da0073e9SAndroid Build Coastguard Worker    def test_autocast_custom_cast_inputs(self):
5083*da0073e9SAndroid Build Coastguard Worker        class MyMM(torch.autograd.Function):
5084*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5085*da0073e9SAndroid Build Coastguard Worker            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
5086*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, container, expect_type):
5087*da0073e9SAndroid Build Coastguard Worker                b = container[1][0]
5088*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(a.dtype is expect_type)
5089*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(b.dtype is expect_type)
5090*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_autocast_enabled())
5091*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(a, b)
5092*da0073e9SAndroid Build Coastguard Worker                return a.mm(b)
5093*da0073e9SAndroid Build Coastguard Worker
5094*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5095*da0073e9SAndroid Build Coastguard Worker            @torch.amp.custom_bwd(device_type="cuda")
5096*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
5097*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_autocast_enabled())
5098*da0073e9SAndroid Build Coastguard Worker                a, b = ctx.saved_tensors
5099*da0073e9SAndroid Build Coastguard Worker                return grad.mm(b.t()), None, None
5100*da0073e9SAndroid Build Coastguard Worker
5101*da0073e9SAndroid Build Coastguard Worker        mymm = MyMM.apply
5102*da0073e9SAndroid Build Coastguard Worker
5103*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
5104*da0073e9SAndroid Build Coastguard Worker        # Puts one input tensor in a nested container.  y's contained Tensor won't receive a gradient,
5105*da0073e9SAndroid Build Coastguard Worker        # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments.
5106*da0073e9SAndroid Build Coastguard Worker        # Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
5107*da0073e9SAndroid Build Coastguard Worker        y = (
5108*da0073e9SAndroid Build Coastguard Worker            0,
5109*da0073e9SAndroid Build Coastguard Worker            {
5110*da0073e9SAndroid Build Coastguard Worker                0: torch.randn(
5111*da0073e9SAndroid Build Coastguard Worker                    (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
5112*da0073e9SAndroid Build Coastguard Worker                )
5113*da0073e9SAndroid Build Coastguard Worker            },
5114*da0073e9SAndroid Build Coastguard Worker        )
5115*da0073e9SAndroid Build Coastguard Worker
5116*da0073e9SAndroid Build Coastguard Worker        with torch.autocast("cuda"):
5117*da0073e9SAndroid Build Coastguard Worker            output = mymm(x, y, torch.float32)
5118*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(output.dtype is torch.float32)
5119*da0073e9SAndroid Build Coastguard Worker            loss = output.sum()
5120*da0073e9SAndroid Build Coastguard Worker        loss.backward()
5121*da0073e9SAndroid Build Coastguard Worker
5122*da0073e9SAndroid Build Coastguard Worker        # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
5123*da0073e9SAndroid Build Coastguard Worker        output = mymm(x, y, torch.float16)
5124*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(output.dtype is torch.float16)
5125*da0073e9SAndroid Build Coastguard Worker        loss = output.sum()
5126*da0073e9SAndroid Build Coastguard Worker        loss.backward()
5127*da0073e9SAndroid Build Coastguard Worker
5128*da0073e9SAndroid Build Coastguard Worker    def test_autocast_custom_deprecated_warning(self):
5129*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
5130*da0073e9SAndroid Build Coastguard Worker
5131*da0073e9SAndroid Build Coastguard Worker            class MyMM(torch.autograd.Function):
5132*da0073e9SAndroid Build Coastguard Worker                @staticmethod
5133*da0073e9SAndroid Build Coastguard Worker                @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
5134*da0073e9SAndroid Build Coastguard Worker                def forward(ctx, x, y):
5135*da0073e9SAndroid Build Coastguard Worker                    ctx.save_for_backward(x, y)
5136*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_autocast_enabled())
5137*da0073e9SAndroid Build Coastguard Worker                    return x + y
5138*da0073e9SAndroid Build Coastguard Worker
5139*da0073e9SAndroid Build Coastguard Worker                @staticmethod
5140*da0073e9SAndroid Build Coastguard Worker                @torch.cuda.amp.custom_bwd
5141*da0073e9SAndroid Build Coastguard Worker                def backward(ctx, grad):
5142*da0073e9SAndroid Build Coastguard Worker                    _, _ = ctx.saved_tensors
5143*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_autocast_enabled())
5144*da0073e9SAndroid Build Coastguard Worker                    return grad, grad
5145*da0073e9SAndroid Build Coastguard Worker
5146*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(
5147*da0073e9SAndroid Build Coastguard Worker            str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
5148*da0073e9SAndroid Build Coastguard Worker        )
5149*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(
5150*da0073e9SAndroid Build Coastguard Worker            str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
5151*da0073e9SAndroid Build Coastguard Worker        )
5152*da0073e9SAndroid Build Coastguard Worker
5153*da0073e9SAndroid Build Coastguard Worker        mymm = MyMM.apply
5154*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, requires_grad=True)
5155*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, 3, requires_grad=True)
5156*da0073e9SAndroid Build Coastguard Worker        with torch.amp.autocast("cuda"):
5157*da0073e9SAndroid Build Coastguard Worker            output = mymm(x, y)
5158*da0073e9SAndroid Build Coastguard Worker            loss = output.sum()
5159*da0073e9SAndroid Build Coastguard Worker        loss.backward()
5160*da0073e9SAndroid Build Coastguard Worker
5161*da0073e9SAndroid Build Coastguard Worker    def test_autocast_cat_jit(self):
5162*da0073e9SAndroid Build Coastguard Worker        # Reported at https://github.com/pytorch/pytorch/issues/38958
5163*da0073e9SAndroid Build Coastguard Worker
5164*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
5165*da0073e9SAndroid Build Coastguard Worker            def forward(self):
5166*da0073e9SAndroid Build Coastguard Worker                a = torch.randn(1)
5167*da0073e9SAndroid Build Coastguard Worker                b = torch.randn(1)
5168*da0073e9SAndroid Build Coastguard Worker                c = torch.cat((a, b), 0)
5169*da0073e9SAndroid Build Coastguard Worker                d = torch.stack([c, c], 0)
5170*da0073e9SAndroid Build Coastguard Worker                return d
5171*da0073e9SAndroid Build Coastguard Worker
5172*da0073e9SAndroid Build Coastguard Worker        # The JIT itself doesn't really matter; we just need to call
5173*da0073e9SAndroid Build Coastguard Worker        # cat via the boxed API.
5174*da0073e9SAndroid Build Coastguard Worker        model = Model()
5175*da0073e9SAndroid Build Coastguard Worker        model_jit_script = torch.jit.script(model)
5176*da0073e9SAndroid Build Coastguard Worker
5177*da0073e9SAndroid Build Coastguard Worker        with torch.autocast("cuda", enabled=True):
5178*da0073e9SAndroid Build Coastguard Worker            model()
5179*da0073e9SAndroid Build Coastguard Worker            model_jit_script()
5180*da0073e9SAndroid Build Coastguard Worker
5181*da0073e9SAndroid Build Coastguard Worker    # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened)
5182*da0073e9SAndroid Build Coastguard Worker    # so they get a dedicated test.
5183*da0073e9SAndroid Build Coastguard Worker    # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100).
5184*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
5185*da0073e9SAndroid Build Coastguard Worker    def test_autocast_rnn(self):
5186*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
5187*da0073e9SAndroid Build Coastguard Worker            # seq, batch, features, hidden size
5188*da0073e9SAndroid Build Coastguard Worker            clses = ("RNN", "GRU", "LSTM")
5189*da0073e9SAndroid Build Coastguard Worker            T, B, F, H = 3, 4, 5, 6
5190*da0073e9SAndroid Build Coastguard Worker            dtypes = (torch.float16, torch.float32)
5191*da0073e9SAndroid Build Coastguard Worker            input_layouts = ("seq_first", "batch_first", "packed")
5192*da0073e9SAndroid Build Coastguard Worker
5193*da0073e9SAndroid Build Coastguard Worker            for (
5194*da0073e9SAndroid Build Coastguard Worker                cls,
5195*da0073e9SAndroid Build Coastguard Worker                num_layers,
5196*da0073e9SAndroid Build Coastguard Worker                bias,
5197*da0073e9SAndroid Build Coastguard Worker                input_layout,
5198*da0073e9SAndroid Build Coastguard Worker                bidirectional,
5199*da0073e9SAndroid Build Coastguard Worker                try_nonpreflattened_weights,
5200*da0073e9SAndroid Build Coastguard Worker                input_dtype,
5201*da0073e9SAndroid Build Coastguard Worker                hidden_dtype,
5202*da0073e9SAndroid Build Coastguard Worker                weight_dtype,
5203*da0073e9SAndroid Build Coastguard Worker            ) in product(
5204*da0073e9SAndroid Build Coastguard Worker                clses,
5205*da0073e9SAndroid Build Coastguard Worker                (1, 2),
5206*da0073e9SAndroid Build Coastguard Worker                (True, False),
5207*da0073e9SAndroid Build Coastguard Worker                input_layouts,
5208*da0073e9SAndroid Build Coastguard Worker                (True, False),
5209*da0073e9SAndroid Build Coastguard Worker                (True, False),
5210*da0073e9SAndroid Build Coastguard Worker                dtypes,
5211*da0073e9SAndroid Build Coastguard Worker                dtypes,
5212*da0073e9SAndroid Build Coastguard Worker                dtypes,
5213*da0073e9SAndroid Build Coastguard Worker            ):
5214*da0073e9SAndroid Build Coastguard Worker                if input_layout == "seq_first":
5215*da0073e9SAndroid Build Coastguard Worker                    batch_first = False
5216*da0073e9SAndroid Build Coastguard Worker                    x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
5217*da0073e9SAndroid Build Coastguard Worker                elif input_layout == "batch_first":
5218*da0073e9SAndroid Build Coastguard Worker                    batch_first = True
5219*da0073e9SAndroid Build Coastguard Worker                    x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
5220*da0073e9SAndroid Build Coastguard Worker                elif input_layout == "packed":
5221*da0073e9SAndroid Build Coastguard Worker                    batch_first = False
5222*da0073e9SAndroid Build Coastguard Worker                    x = torch.nn.utils.rnn.pack_padded_sequence(
5223*da0073e9SAndroid Build Coastguard Worker                        torch.randn((T, B, F), device="cuda", dtype=input_dtype),
5224*da0073e9SAndroid Build Coastguard Worker                        lengths=(3, 2, 1, 3),
5225*da0073e9SAndroid Build Coastguard Worker                        enforce_sorted=False,
5226*da0073e9SAndroid Build Coastguard Worker                    )
5227*da0073e9SAndroid Build Coastguard Worker
5228*da0073e9SAndroid Build Coastguard Worker                rnn = (
5229*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.nn, cls)(
5230*da0073e9SAndroid Build Coastguard Worker                        F,
5231*da0073e9SAndroid Build Coastguard Worker                        H,
5232*da0073e9SAndroid Build Coastguard Worker                        num_layers=num_layers,
5233*da0073e9SAndroid Build Coastguard Worker                        bidirectional=bidirectional,
5234*da0073e9SAndroid Build Coastguard Worker                        bias=bias,
5235*da0073e9SAndroid Build Coastguard Worker                        batch_first=batch_first,
5236*da0073e9SAndroid Build Coastguard Worker                    )
5237*da0073e9SAndroid Build Coastguard Worker                    .cuda()
5238*da0073e9SAndroid Build Coastguard Worker                    .to(dtype=weight_dtype)
5239*da0073e9SAndroid Build Coastguard Worker                )
5240*da0073e9SAndroid Build Coastguard Worker
5241*da0073e9SAndroid Build Coastguard Worker                if try_nonpreflattened_weights:
5242*da0073e9SAndroid Build Coastguard Worker                    for p in rnn.parameters():
5243*da0073e9SAndroid Build Coastguard Worker                        with torch.no_grad():
5244*da0073e9SAndroid Build Coastguard Worker                            p.set_(p.clone())
5245*da0073e9SAndroid Build Coastguard Worker
5246*da0073e9SAndroid Build Coastguard Worker                h = torch.randn(
5247*da0073e9SAndroid Build Coastguard Worker                    (num_layers * (2 if bidirectional else 1), B, H),
5248*da0073e9SAndroid Build Coastguard Worker                    device="cuda",
5249*da0073e9SAndroid Build Coastguard Worker                    dtype=hidden_dtype,
5250*da0073e9SAndroid Build Coastguard Worker                )
5251*da0073e9SAndroid Build Coastguard Worker                if cls == "LSTM":
5252*da0073e9SAndroid Build Coastguard Worker                    c = torch.randn(
5253*da0073e9SAndroid Build Coastguard Worker                        (num_layers * (2 if bidirectional else 1), B, H),
5254*da0073e9SAndroid Build Coastguard Worker                        device="cuda",
5255*da0073e9SAndroid Build Coastguard Worker                        dtype=hidden_dtype,
5256*da0073e9SAndroid Build Coastguard Worker                    )
5257*da0073e9SAndroid Build Coastguard Worker                    h = (h, c)
5258*da0073e9SAndroid Build Coastguard Worker
5259*da0073e9SAndroid Build Coastguard Worker                with torch.autocast("cuda"):
5260*da0073e9SAndroid Build Coastguard Worker                    out, h_out = rnn(x, h)
5261*da0073e9SAndroid Build Coastguard Worker                out = out.data if input_layout == "packed" else out
5262*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.dtype, torch.float16)
5263*da0073e9SAndroid Build Coastguard Worker                # The autocast wrapper requires that at::_cudnn_rnn be autograd-exposed.  This check
5264*da0073e9SAndroid Build Coastguard Worker                # can't guarantee that, but if it fires, it indicates something unexpected has happened
5265*da0073e9SAndroid Build Coastguard Worker                # and we should double-check that at::_cudnn_rnn remains autograd-exposed.
5266*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5267*da0073e9SAndroid Build Coastguard Worker                    out.grad_fn.name(),
5268*da0073e9SAndroid Build Coastguard Worker                    "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
5269*da0073e9SAndroid Build Coastguard Worker                )
5270*da0073e9SAndroid Build Coastguard Worker                out.sum().backward()
5271*da0073e9SAndroid Build Coastguard Worker                grads = [p.grad.clone() for p in rnn.parameters()]
5272*da0073e9SAndroid Build Coastguard Worker
5273*da0073e9SAndroid Build Coastguard Worker                rnn.zero_grad()
5274*da0073e9SAndroid Build Coastguard Worker
5275*da0073e9SAndroid Build Coastguard Worker                if cls == "LSTM":
5276*da0073e9SAndroid Build Coastguard Worker                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
5277*da0073e9SAndroid Build Coastguard Worker                        x.half(), (h[0].half(), h[1].half())
5278*da0073e9SAndroid Build Coastguard Worker                    )
5279*da0073e9SAndroid Build Coastguard Worker                else:
5280*da0073e9SAndroid Build Coastguard Worker                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
5281*da0073e9SAndroid Build Coastguard Worker                        x.half(), h.half()
5282*da0073e9SAndroid Build Coastguard Worker                    )
5283*da0073e9SAndroid Build Coastguard Worker                out_control = (
5284*da0073e9SAndroid Build Coastguard Worker                    out_control.data if input_layout == "packed" else out_control
5285*da0073e9SAndroid Build Coastguard Worker                )
5286*da0073e9SAndroid Build Coastguard Worker                out_control.sum().backward()
5287*da0073e9SAndroid Build Coastguard Worker                grads_control = [p.grad.clone() for p in rnn.parameters()]
5288*da0073e9SAndroid Build Coastguard Worker
5289*da0073e9SAndroid Build Coastguard Worker                # Compares with default tolerances, even for FP16 execution.  Barring nondeterminism,
5290*da0073e9SAndroid Build Coastguard Worker                # autocast and control results should be bitwise identical.
5291*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, out_control)
5292*da0073e9SAndroid Build Coastguard Worker
5293*da0073e9SAndroid Build Coastguard Worker                if cls == "LSTM":
5294*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
5295*da0073e9SAndroid Build Coastguard Worker                        h_out[0].dtype is torch.float16
5296*da0073e9SAndroid Build Coastguard Worker                        and h_out[1].dtype is torch.float16
5297*da0073e9SAndroid Build Coastguard Worker                    )
5298*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(h_out[0], h_out_control[0])
5299*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(h_out[1], h_out_control[1])
5300*da0073e9SAndroid Build Coastguard Worker                else:
5301*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(h_out.dtype, torch.float16)
5302*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(h_out, h_out_control)
5303*da0073e9SAndroid Build Coastguard Worker                for grad, grad_control in zip(grads, grads_control):
5304*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grad.half(), grad_control)
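
    # Minimal illustrative counterpart to the exhaustive loop above: a single
    # LSTM call under autocast should come back in float16 (assumes cuDNN is
    # available, as the surrounding test does).  Hypothetical helper, not run
    # by CI.
    def _sketch_autocast_lstm_fp16(self):
        lstm = torch.nn.LSTM(5, 6, num_layers=1).cuda()
        x = torch.randn(3, 4, 5, device="cuda")
        with torch.autocast("cuda"):
            out, (h, c) = lstm(x)
        self.assertEqual(out.dtype, torch.float16)
        self.assertEqual(h.dtype, torch.float16)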
5305*da0073e9SAndroid Build Coastguard Worker
5306*da0073e9SAndroid Build Coastguard Worker    def test_autocast_cache_leak(self):
5307*da0073e9SAndroid Build Coastguard Worker        # Reported at https://github.com/pytorch/pytorch/issues/48049
5308*da0073e9SAndroid Build Coastguard Worker        # This test checks whether autocast re-caches the same parameters when
5309*da0073e9SAndroid Build Coastguard Worker        # executed inside a `torch.no_grad()` block.
5310*da0073e9SAndroid Build Coastguard Worker
5311*da0073e9SAndroid Build Coastguard Worker        linear = torch.nn.Linear(10, 10).to("cuda")
5312*da0073e9SAndroid Build Coastguard Worker        data = torch.randn(1, 10, device="cuda")
5313*da0073e9SAndroid Build Coastguard Worker
5314*da0073e9SAndroid Build Coastguard Worker        with torch.autocast("cuda"):
5315*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
5316*da0073e9SAndroid Build Coastguard Worker                out = linear(data)
5317*da0073e9SAndroid Build Coastguard Worker                first_iter_mem = torch.cuda.memory_allocated()
5318*da0073e9SAndroid Build Coastguard Worker                for _ in range(3):
5319*da0073e9SAndroid Build Coastguard Worker                    out = linear(data)
5320*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(first_iter_mem, torch.cuda.memory_allocated())
5321*da0073e9SAndroid Build Coastguard Worker
5322*da0073e9SAndroid Build Coastguard Worker    def test_autocast_checkpointing(self):
5323*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Sequential(
5324*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
5325*da0073e9SAndroid Build Coastguard Worker        ).cuda()
5326*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(
5327*da0073e9SAndroid Build Coastguard Worker            (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
5328*da0073e9SAndroid Build Coastguard Worker        )
5329*da0073e9SAndroid Build Coastguard Worker        for reentrant in (True, False):
5330*da0073e9SAndroid Build Coastguard Worker            with torch.autocast("cuda"):
5331*da0073e9SAndroid Build Coastguard Worker                output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
5332*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(output.requires_grad)
5333*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(output.dtype is torch.float16)
5334*da0073e9SAndroid Build Coastguard Worker            output.sum().backward()
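
    # Illustrative variant of the test above using plain
    # torch.utils.checkpoint.checkpoint in its non-reentrant form; the autocast
    # dtype should propagate through the checkpointed call the same way.
    # Hypothetical helper, not run by CI.
    def _sketch_autocast_plain_checkpoint(self):
        import torch.utils.checkpoint as checkpoint_mod

        linear = torch.nn.Linear(8, 8).cuda()
        x = torch.rand((8, 8), device="cuda", requires_grad=True)
        with torch.autocast("cuda"):
            out = checkpoint_mod.checkpoint(linear, x, use_reentrant=False)
        self.assertTrue(out.dtype is torch.float16)
        out.sum().backward()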
5335*da0073e9SAndroid Build Coastguard Worker
5336*da0073e9SAndroid Build Coastguard Worker    def test_cuda_autocast_deprecated_warning(self):
5337*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
5338*da0073e9SAndroid Build Coastguard Worker            FutureWarning,
5339*da0073e9SAndroid Build Coastguard Worker            r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
5340*da0073e9SAndroid Build Coastguard Worker        ):
5341*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.amp.autocast():
5342*da0073e9SAndroid Build Coastguard Worker                _ = torch.ones(10)
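
    # For reference, a sketch of the replacement spelling named in the warning
    # above; this is the form used elsewhere in this file.  Hypothetical
    # helper, not run by CI.
    def _sketch_recommended_autocast_spelling(self):
        with torch.amp.autocast("cuda"):
            _ = torch.ones(10)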
5343*da0073e9SAndroid Build Coastguard Worker
5344*da0073e9SAndroid Build Coastguard Worker
5345*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestCuda)
5346*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestCudaMallocAsync)
5347*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestCudaOptims, globals())
5348*da0073e9SAndroid Build Coastguard Worker
5349*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
5350*da0073e9SAndroid Build Coastguard Worker    run_tests()