xref: /aosp_15_r20/external/pytorch/test/test_cuda.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: cuda"]
2
3import contextlib
4import ctypes
5import gc
6import json
7import os
8import pickle
9import random
10import subprocess
11import sys
12import tempfile
13import threading
14import unittest
15import warnings
16from copy import deepcopy
17from itertools import product
18from random import randint
19
20import psutil
21
22import torch
23import torch.cuda
24from torch import inf, nan
25from torch.cuda._memory_viz import (
26    _profile_to_snapshot,
27    profile_plot,
28    segment_plot,
29    trace_plot,
30)
31from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
32from torch.testing._internal.common_cuda import (
33    _create_scaling_case,
34    _get_torch_cuda_version,
35    TEST_CUDNN,
36    TEST_MULTIGPU,
37)
38from torch.testing._internal.common_device_type import (
39    instantiate_device_type_tests,
40    onlyCUDA,
41    onlyNativeDeviceTypes,
42)
43from torch.testing._internal.common_optimizers import (
44    _get_optim_inputs_including_global_cliquey_kwargs,
45    optim_db,
46    optims,
47    TensorTracker,
48)
49from torch.testing._internal.common_utils import (
50    EXPANDABLE_SEGMENTS,
51    freeze_rng_state,
52    gcIfJetson,
53    get_cycles_per_ms,
54    instantiate_parametrized_tests,
55    IS_ARM64,
56    IS_FBCODE,
57    IS_JETSON,
58    IS_LINUX,
59    IS_SANDCASTLE,
60    IS_WINDOWS,
61    load_tests,
62    NO_MULTIPROCESSING_SPAWN,
63    parametrize,
64    run_tests,
65    serialTest,
66    skipCUDAMemoryLeakCheckIf,
67    skipCUDANonDefaultStreamIf,
68    skipIfRocm,
69    slowTest,
70    subtest,
71    TemporaryFileName,
72    TEST_CUDA,
73    TEST_CUDA_GRAPH,
74    TEST_NUMPY,
75    TEST_WITH_ROCM,
76    TestCase,
77)
78from torch.utils.checkpoint import checkpoint_sequential
79from torch.utils.viz._cycles import observe_tensor_cycles
80
81
82# load_tests from common_utils is used to automatically filter tests for
83# sharding on sandcastle. This line silences flake warnings
84load_tests = load_tests
85
86try:
87    import torchvision.models  # noqa: F401
88    from torchvision.models import resnet18  # noqa: F401
89
90    HAS_TORCHVISION = True
91except ImportError:
92    HAS_TORCHVISION = False
93skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
94
95TEST_CUDAMALLOCASYNC = TEST_CUDA and (
96    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
97)
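# Note (added comment): the allocator backend is chosen when the CUDA allocator is initialized
# (e.g. via PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync) and cannot change afterwards,
# so this flag is stable for the whole test run.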
98TEST_LARGE_TENSOR = TEST_CUDA
99TEST_MEDIUM_TENSOR = TEST_CUDA
100TEST_BF16 = False
101TEST_PYNVML = not torch.cuda._HAS_PYNVML
102if TEST_CUDA:
103    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
104    TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9
105    TEST_BF16 = torch.cuda.is_bf16_supported()
106
107_cycles_per_ms = None
108
109
110@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
111@torch.testing._internal.common_utils.markDynamoStrictTest
112class TestCuda(TestCase):
113    _do_cuda_memory_leak_check = True
114    _do_cuda_non_default_stream = True
115    FIFTY_MIL_CYCLES = 50000000
116
117    def setUp(self):
118        super().setUp()
119
120    def tearDown(self):
121        super().tearDown()
122
123    @property
124    def expandable_segments(self):
125        return EXPANDABLE_SEGMENTS
126
127    def test_pinned_memory_with_cudaregister(self):
128        torch.cuda.memory._set_allocator_settings(
129            "pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
130        )
131        t = torch.ones(20)
132        self.assertFalse(t.is_pinned())
133        try:
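            # 1 << 21 and 1 << 24 float32 elements are 8 MiB and 64 MiB pinned buffers,
            # exercising the cudaHostRegister-based pinning path at two different sizes.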
134            pinned_t = torch.ones(1 << 21).pin_memory()
135            self.assertTrue(pinned_t.is_pinned())
136            pinned_t = torch.ones(1 << 24).pin_memory()
137            self.assertTrue(pinned_t.is_pinned())
138        except RuntimeError as e:
139            # Some GPUs don't support the same address space on the host and device sides
140            pass
141
142    def test_pinned_memory_with_cudaregister_multithread(self):
143        num_threads = 4
144        threads = [
145            threading.Thread(target=self.test_pinned_memory_with_cudaregister)
146            for t in range(num_threads)
147        ]
148        for thread in threads:
149            thread.start()
150        for thread in threads:
151            thread.join()
152
153    def test_pinned_memory_empty_cache(self):
154        for alloc_settings in (True, False):
155            torch.cuda.memory._set_allocator_settings(
156                f"pinned_use_cuda_host_register:{alloc_settings}"
157            )
158            try:
159                t = torch.ones(1024 * 1024, pin_memory=True)
160                self.assertTrue(t.is_pinned())
161                del t
162                torch._C._host_emptyCache()
163            except RuntimeError as e:
164                # Some GPUs don't support the same address space on the host and device sides
165                pass
166
167    def test_cudart_register(self):
168        t = torch.ones(20)
169        self.assertFalse(t.is_pinned())
170        cudart = torch.cuda.cudart()
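        # cudaHostRegister pins the existing host allocation in place; a return code of 0 is cudaSuccess.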
171        r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
172        self.assertEqual(r, 0)
173        self.assertTrue(t.is_pinned())
174        r = cudart.cudaHostUnregister(t.data_ptr())
175        self.assertEqual(r, 0)
176        self.assertFalse(t.is_pinned())
177
178    def test_memory_allocation(self):
179        gc.collect()
180        torch.cuda.empty_cache()
181        mem = None
182        size = 1
183        prev = 0
184        try:
185            prev = torch.cuda.memory_allocated()
186            mem = torch.cuda.caching_allocator_alloc(size)
187            self.assertGreater(torch.cuda.memory_allocated(), prev)
188        finally:
189            if mem is not None:
190                torch.cuda.caching_allocator_delete(mem)
191                self.assertEqual(torch.cuda.memory_allocated(), prev)
192
193    def test_check_error(self):
194        # Assert this call doesn't raise.
195        torch.cuda.check_error(0)
196
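        # Error code 2 is cudaErrorMemoryAllocation (hipErrorOutOfMemory on ROCm), matching the regex below.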
197        with self.assertRaisesRegex(
198            torch.cuda.CudaError, "out of memory|hipErrorOutOfMemory"
199        ):
200            torch.cuda.check_error(2)
201
202    def test_cuda_get_device_name(self):
203        # Testing the behaviour with None as an argument
204        current_device = torch.cuda.current_device()
205        current_device_name = torch.cuda.get_device_name(current_device)
206        device_name_None = torch.cuda.get_device_name(None)
207        self.assertEqual(current_device_name, device_name_None)
208
209        # Testing the behaviour with no argument
210        device_name_no_argument = torch.cuda.get_device_name()
211        self.assertEqual(current_device_name, device_name_no_argument)
212
213    def test_cuda_get_device_capability(self):
214        # Testing the behaviour with None as an argument
215        current_device = torch.cuda.current_device()
216        current_device_capability = torch.cuda.get_device_capability(current_device)
217        device_capability_None = torch.cuda.get_device_capability(None)
218        self.assertEqual(current_device_capability, device_capability_None)
219
220        # Testing the behaviour with no argument
221        device_capability_no_argument = torch.cuda.get_device_capability()
222        self.assertEqual(current_device_capability, device_capability_no_argument)
223
224    def test_out_of_memory(self):
225        tensor = torch.zeros(1024, device="cuda")
226
227        oom_regex = (
228            "would exceed allowed memory"
229            if TEST_CUDAMALLOCASYNC
230            else "Tried to allocate 800000000.00 GiB"
231        )
232        with self.assertRaisesRegex(RuntimeError, oom_regex):
233            torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="cuda")
234
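        # 1024**3 * 8e9 bytes is roughly 8.6 EB, which exceeds the allocator's explicit 1 EB limit
        # (see the expected error message below).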
235        with self.assertRaisesRegex(
236            RuntimeError, "Tried to allocate more than 1EB memory"
237        ):
238            torch.empty(
239                1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="cuda"
240            )
241
242        # ensure out of memory error doesn't disturb subsequent kernel
243        tensor.fill_(1)
244        self.assertTrue((tensor == 1).all())
245
246    @unittest.skipIf(
247        TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)"
248    )
249    @serialTest()
250    def test_out_of_memory_retry(self):
251        torch.cuda.empty_cache()
252        total_memory = torch.cuda.get_device_properties(0).total_memory
253        oom_regex = (
254            "would exceed allowed memory"
255            if TEST_CUDAMALLOCASYNC
256            else "Tried to allocate"
257        )
258        size = int(total_memory * 0.5)
259        a = torch.empty(size, dtype=torch.int8, device="cuda")
260        with self.assertRaisesRegex(RuntimeError, oom_regex):
261            b = torch.empty(size, dtype=torch.int8, device="cuda")
262        del a
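        # Once `a` is freed, its cached block can be reused, so the same-sized allocation now succeeds.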
263        b = torch.empty(size, dtype=torch.int8, device="cuda")
264        del b
265        # We used a lot of memory here; clean up so we don't affect other tests too much
266        torch.cuda.empty_cache()
267        torch.cuda.reset_peak_memory_stats()
268
269    @serialTest()
270    def test_set_per_process_memory_fraction(self):
271        # test invalid fraction value.
272        with self.assertRaisesRegex(TypeError, "Invalid type"):
273            torch.cuda.set_per_process_memory_fraction(1)
274        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
275            torch.cuda.set_per_process_memory_fraction(-0.1)
276        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
277            torch.cuda.set_per_process_memory_fraction(2.0)
278
279        tensor = torch.zeros(1024, device="cuda")
280        torch.cuda.empty_cache()
281        total_memory = torch.cuda.get_device_properties(0).total_memory
282        torch.cuda.set_per_process_memory_fraction(0.5, 0)
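        # From here on, the caching allocator caps this process at 0.5 * total_memory on device 0.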
283
284        # test 0.499 allocation is ok.
285        application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved()
286        tmp_tensor = torch.empty(application, dtype=torch.int8, device="cuda")
287        del tmp_tensor
288        torch.cuda.empty_cache()
289
290        application = int(total_memory * 0.5)
291        # It will OOM when trying to allocate more than half of the memory.
292        oom_regex = (
293            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
294        )
295        with self.assertRaisesRegex(RuntimeError, oom_regex):
296            torch.empty(application, dtype=torch.int8, device="cuda")
297
298        # ensure out of memory error doesn't disturb subsequent kernel
299        tensor.fill_(1)
300        self.assertTrue((tensor == 1).all())
301
302    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "uuid attribute not yet available")
303    def test_uuid(self):
304        uuid = torch.cuda.get_device_properties(0).uuid
305        self.assertEqual(len(str(uuid)), 36)  # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
306        self.assertEqual(len(uuid.bytes), 16)
307
308    def test_copy_non_blocking(self):
309        def _test_copy_non_blocking(a, b):
310            event = torch.cuda.Event()
311            a.copy_(b, non_blocking=True)
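            # record() enqueues the event on the current stream after the copy;
            # synchronize() blocks the CPU until it completes, so the comparison sees the finished copy.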
312            event.record()
313            event.synchronize()
314            self.assertEqual(a, b)
315
316        # 10MB copies
317        x = torch.ones(10000000, dtype=torch.uint8).cuda()
318        y = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
319        _test_copy_non_blocking(x, y)
320
321        x = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
322        y = torch.ones(10000000, dtype=torch.uint8).cuda()
323        _test_copy_non_blocking(x, y)
324
325        # Test the case where the pinned data_ptr is not equal to the storage data_ptr.
326        x_base = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
327        x = x_base[1:]
328        self.assertTrue(x.is_pinned())
329        self.assertTrue(x_base.is_pinned())
330        self.assertNotEqual(x_base.data_ptr(), x.data_ptr())
331        self.assertEqual(x_base.storage().data_ptr(), x.storage().data_ptr())
332        y = torch.ones(10000000 - 1, dtype=torch.uint8).cuda()
333        _test_copy_non_blocking(x, y)
334
335    def test_copy_non_blocking_type_conversion(self):
336        a = torch.ones(1, device="cuda")
337        b = torch.zeros(1, device="cpu", pin_memory=True)
338        c = torch.empty(1, device="cuda", dtype=torch.long)
339        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
340        b.copy_(a, non_blocking=True)
341        c.copy_(b, non_blocking=True)
342        self.assertEqual(a, c, exact_dtype=False)
343
344    @serialTest()
345    def test_to_non_blocking(self):
346        stream = torch.cuda.current_stream()
347
348        def _test_to_non_blocking(a, non_blocking, dst):
349            torch.cuda.synchronize()
350            # Pushes a 0.1 second spin onto the stream so that, if the copy is non-blocking,
351            # the stream will almost surely still be active when we query().
352            torch.cuda._sleep(int(100 * get_cycles_per_ms()))
353            b = a.to(device=dst, non_blocking=non_blocking)
354            self.assertEqual(stream.query(), not non_blocking)
355            stream.synchronize()
356            self.assertEqual(a, b)
357            self.assertTrue(b.is_pinned() == (non_blocking and dst == "cpu"))
358
359        for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)):
360            # Creates source on the opposite device from destination.
361            src = torch.randn(
362                1000000,
363                device="cuda" if dst == "cpu" else "cpu",
364                pin_memory=True if dst == "cuda" else False,
365            )
366            _test_to_non_blocking(src, try_non_blocking, dst)
367
368    def test_to_cpu_blocking_by_default(self):
369        src = torch.randn(1000000, device="cuda")
370        torch.cuda.synchronize()
371        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
372        dst = src.to(device="cpu")
373        self.assertEqual(torch.cuda.current_stream().query(), True)
374        self.assertEqual(src, dst)
375        self.assertFalse(dst.is_pinned())
376
377    def test_serialization_array_with_storage(self):
378        x = torch.randn(5, 5).cuda()
379        y = torch.IntTensor(2, 5).fill_(0).cuda()
380        q = [x, y, x, y.storage()]
381        with tempfile.NamedTemporaryFile() as f:
382            torch.save(q, f)
383            f.seek(0)
384            q_copy = torch.load(f)
385        self.assertEqual(q_copy, q, atol=0, rtol=0)
386        q_copy[0].fill_(5)
387        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
388        self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor))
389        self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
390        self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
391        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
392        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
393        q_copy[1].fill_(10)
394        self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
395
396    @unittest.skipIf(
397        TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
398    )
399    @unittest.skipIf(
400        _get_torch_cuda_version() >= (12, 2),
401        "skipped as explicit workspace allocation is removed",
402    )
403    def test_cublas_workspace_explicit_allocation(self):
404        a = torch.randn(7, 7, device="cuda", requires_grad=False)
405        default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
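        # As the arithmetic above shows, CUBLAS_WORKSPACE_CONFIG is parsed as ":size_KiB:count" pairs,
        # so :4096:2:16:8 means 2 x 4096 KiB + 8 x 16 KiB of workspace.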
406        # different size (32 MiB) expected on Hopper GPU
407        if torch.cuda.get_device_capability() == (9, 0):
408            default_workspace_size = 4096 * 8 * 1024
409
410        def check_workspace_size(inp):
411            torch._C._cuda_clearCublasWorkspaces()
412            start = torch.cuda.memory_stats()["active_bytes.all.allocated"]
413            with torch.no_grad():
414                torch.matmul(inp, inp)
415            finish = torch.cuda.memory_stats()["active_bytes.all.allocated"]
416            return finish - start
417
418        # check default
419        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
420        self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288)
421
422        # check default with bad user config
423        os.environ["CUBLAS_WORKSPACE_CONFIG"] = "-1"
424        self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288)
425
426        # check valid config
427        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":128:8:64:16:32:32"
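        # 8 * 128 KiB + 16 * 64 KiB + 32 * 32 KiB = 3072 KiB, which is the size asserted below.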
428        self.assertTrue(abs(check_workspace_size(a) - (3072 * 1024)) < 524288)
429
430        torch._C._cuda_clearCublasWorkspaces()
431
432    def test_cublas_allow_tf32_get_set(self):
433        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
434            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
435        )
436        if skip_tf32_cublas:
437            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
438            return
439
440        orig = torch.backends.cuda.matmul.allow_tf32
441        self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
442        torch.backends.cuda.matmul.allow_tf32 = not orig
443        self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
444        torch.backends.cuda.matmul.allow_tf32 = orig
445
446    def test_float32_matmul_precision_get_set(self):
447        orig = torch.get_float32_matmul_precision()
448        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
449            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
450        )
451        # this is really just checking that the environment variable is respected during testing
452        # and not overwritten by another function that doesn't revert it to the initial value
453        if not skip_tf32_cublas:
454            self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
455            self.assertEqual(torch.get_float32_matmul_precision(), "highest")
456        else:
457            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
458        for p in ("medium", "high"):
459            torch.set_float32_matmul_precision(p)
460            self.assertEqual(torch.get_float32_matmul_precision(), p)
461            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
462        torch.set_float32_matmul_precision("highest")
463        self.assertEqual(torch.get_float32_matmul_precision(), "highest")
464        self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
465        torch.set_float32_matmul_precision(orig)
466
467    def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
468        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
469        self.assertEqual(
470            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig
471        )
472        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
473        self.assertEqual(
474            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig
475        )
476        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
477
478    def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
479        orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
480        self.assertEqual(
481            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig
482        )
483        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
484        self.assertEqual(
485            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig
486        )
487        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
488
489    def test_cudnn_allow_tf32_get_set(self):
490        with torch.backends.cudnn.flags(
491            enabled=None, benchmark=None, deterministic=None, allow_tf32=False
492        ):
493            self.assertFalse(torch.backends.cudnn.allow_tf32)
494        with torch.backends.cudnn.flags(
495            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
496        ):
497            self.assertTrue(torch.backends.cudnn.allow_tf32)
498
499    def test_type_conversions(self):
500        x = torch.randn(5, 5)
501        self.assertIsInstance(x.float(), torch.FloatTensor)
502        self.assertIsInstance(x.cuda().double(), torch.cuda.DoubleTensor)
503        self.assertIsInstance(x.cuda().float(), torch.cuda.FloatTensor)
504        self.assertIsInstance(x.cuda().float().cpu(), torch.FloatTensor)
505        self.assertIsInstance(x.cuda().float().cpu().int(), torch.IntTensor)
506
507        y = x.storage()
508        self.assertIsInstance(y.float(), torch.FloatStorage)
509        self.assertIsInstance(y.cuda().double(), torch.cuda.DoubleStorage)
510        self.assertIsInstance(y.cuda().float(), torch.cuda.FloatStorage)
511        self.assertIsInstance(y.cuda().float().cpu(), torch.FloatStorage)
512        self.assertIsInstance(y.cuda().float().cpu().int(), torch.IntStorage)
513
514    @unittest.skip("was disabled due to not enough memory, but actually it always fail")
515    def test_arithmetic_large_tensor(self):
516        x = torch.empty(2**30, device="cuda")
517
518        x.fill_(1)
519        self.assertEqual(x.sum(), 2**30)
520
521        x += 1
522        self.assertEqual(x.sum(), 2**31)
523
524        x.fill_(1)
525        x -= 0.5
526        self.assertEqual(x.sum(), 2**29)
527
528        x.fill_(1)
529        x *= 2
530        self.assertEqual(x.sum(), 2**31)
531
532        x.fill_(1)
533        x /= 2
534        self.assertEqual(x.sum(), 2**29)
535
536    def test_gather_bool(self):
537        t = torch.tensor([[False, True], [True, True]], device="cuda")
538        self.assertEqual(
539            torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]], device="cuda")),
540            torch.tensor([[False, False], [True, True]], device="cuda"),
541        )
542
543    def test_torch_manual_seed_seeds_cuda_devices(self):
544        with freeze_rng_state():
545            x = torch.zeros(4, 4).float().cuda()
546            torch.manual_seed(2)
547            self.assertEqual(torch.cuda.initial_seed(), 2)
548            x.uniform_()
549            torch.manual_seed(2)
550            y = x.clone().uniform_()
551            self.assertEqual(x, y)
552            self.assertEqual(torch.cuda.initial_seed(), 2)
553
554    def test_manual_seed(self):
555        with freeze_rng_state():
556            x = torch.zeros(4, 4).float().cuda()
557            torch.cuda.manual_seed(2)
558            self.assertEqual(torch.cuda.initial_seed(), 2)
559            x.uniform_()
560            a = torch.bernoulli(torch.full_like(x, 0.5))
561            torch.cuda.manual_seed(2)
562            y = x.clone().uniform_()
563            b = torch.bernoulli(torch.full_like(x, 0.5))
564            self.assertEqual(x, y)
565            self.assertEqual(a, b)
566            self.assertEqual(torch.cuda.initial_seed(), 2)
567
568    def test_specify_improper_device_name(self):
569        import os
570
571        fname = "tempfile.pt"
572        try:
573            with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
574                torch.save(
575                    [torch.nn.Parameter(torch.randn(10, 10))],
576                    fname,
577                    _use_new_zipfile_serialization=True,
578                )
579                torch.load(fname, "cuda0")
580        finally:
581            if os.path.exists(fname):
582                os.remove(fname)
583
584    def test_get_device_index(self):
585        from torch.cuda._utils import _get_device_index
586
587        with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
588            _get_device_index("cuda0", optional=True)
589
590        with self.assertRaisesRegex(ValueError, "Expected a cuda device"):
591            cpu_device = torch.device("cpu")
592            _get_device_index(cpu_device, optional=True)
593
594    def test_serialization_array_with_empty(self):
595        x = [torch.randn(4, 4).cuda(), torch.cuda.FloatTensor()]
596        with tempfile.NamedTemporaryFile() as f:
597            torch.save(x, f)
598            f.seek(0)
599            x_copy = torch.load(f)
600        for original, copy in zip(x, x_copy):
601            self.assertEqual(copy, original)
602            self.assertIs(type(copy), type(original))
603            self.assertEqual(copy.get_device(), original.get_device())
604
605    @skipCUDANonDefaultStreamIf(True)
606    def test_streams(self):
607        default_stream = torch.cuda.current_stream()
608        user_stream = torch.cuda.Stream()
609        self.assertEqual(torch.cuda.current_stream(), default_stream)
610        self.assertNotEqual(default_stream, user_stream)
611        self.assertEqual(default_stream.cuda_stream, 0)
612        self.assertNotEqual(user_stream.cuda_stream, 0)
613        with torch.cuda.stream(user_stream):
614            self.assertEqual(torch.cuda.current_stream(), user_stream)
615        self.assertTrue(user_stream.query())
616        tensor1 = torch.ByteTensor(5).pin_memory()
617        tensor2 = tensor1.cuda(non_blocking=True) + 1
618        default_stream.synchronize()
619        self.assertTrue(default_stream.query())
620
621    def test_stream_event_repr(self):
622        s = torch.cuda.current_stream()
623        self.assertTrue("torch.cuda.Stream" in s.__repr__())
624        e = torch.cuda.Event()
625        self.assertTrue("torch.cuda.Event" in e.__repr__())
626        s.record_event(e)
627        self.assertTrue("torch.cuda.Event" in e.__repr__())
628
629    def test_events(self):
630        stream = torch.cuda.current_stream()
631        event = torch.cuda.Event(enable_timing=True)
632        self.assertTrue(event.query())
633        start_event = torch.cuda.Event(enable_timing=True)
634        stream.record_event(start_event)
635        torch.cuda._sleep(int(50 * get_cycles_per_ms()))
636        stream.record_event(event)
637        self.assertFalse(event.query())
638        event.synchronize()
639        self.assertTrue(event.query())
640        self.assertGreater(start_event.elapsed_time(event), 0)
641
642    def test_generic_stream_event(self):
643        stream = torch.Stream("cuda")
644        self.assertEqual(stream.device_index, torch.cuda.current_device())
645        cuda_stream = torch.cuda.Stream(
646            stream_id=stream.stream_id,
647            device_index=stream.device_index,
648            device_type=stream.device_type,
649        )
650        self.assertEqual(stream.stream_id, cuda_stream.stream_id)
651        self.assertNotEqual(stream.stream_id, torch.cuda.current_stream().stream_id)
652
653        event1 = torch.Event("cuda", enable_timing=True)
654        event2 = torch.Event("cuda", enable_timing=True)
655        self.assertEqual(event1.event_id, 0)
656        a = torch.randn(1000)
657        b = torch.randn(1000)
658        with torch.cuda.stream(cuda_stream):
659            a_cuda = a.to("cuda", non_blocking=True)
660            b_cuda = b.to("cuda", non_blocking=True)
661            self.assertEqual(stream.stream_id, torch.cuda.current_stream().stream_id)
662        event1.record(stream)
663        event1.synchronize()
664        self.assertTrue(event1.query())
665        c_cuda = a_cuda + b_cuda
666        event2.record()
667        event2.synchronize()
668        self.assertTrue(event2.query())
669        self.assertNotEqual(event1.event_id, event2.event_id)
670        self.assertEqual(c_cuda.cpu(), a + b)
671        self.assertTrue(event1.elapsed_time(event2) > 0)
672
673    def test_record_stream(self):
674        cycles_per_ms = get_cycles_per_ms()
675
676        t = torch.FloatTensor([1, 2, 3, 4]).pin_memory()
677        result = torch.cuda.FloatTensor(t.size())
678        stream = torch.cuda.Stream()
679        ptr = [None]
680
681        # Performs the CPU->GPU copy in a background stream
682        def perform_copy():
683            with torch.cuda.stream(stream):
684                tmp = t.cuda(non_blocking=True)
685                ptr[0] = tmp.data_ptr()
686            torch.cuda.current_stream().wait_stream(stream)
687            tmp.record_stream(torch.cuda.current_stream())
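            # record_stream marks tmp's block as in use on the current stream, so the caching allocator
            # won't hand it out again until the work queued there (the delayed copy below) has finished.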
688            torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the copy
689            result.copy_(tmp)
690
691        perform_copy()
692        with torch.cuda.stream(stream):
693            tmp2 = torch.cuda.FloatTensor(t.size())
694            tmp2.zero_()
695            self.assertNotEqual(
696                tmp2.data_ptr(), ptr[0], msg="allocation re-used too soon"
697            )
698
699        self.assertEqual(result.tolist(), [1, 2, 3, 4])
700
701        if not TEST_CUDAMALLOCASYNC:
702            # In the native allocator, we expect "tmp"'s side-stream-tagged block will be reused
703            # in that side stream after result.copy_(tmp) in the main stream finishes.
704            torch.cuda.current_stream().synchronize()
705            with torch.cuda.stream(stream):
706                tmp3 = torch.cuda.FloatTensor(t.size())
707                self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not re-used")
708
709    def test_record_stream_on_shifted_view(self):
710        # See issue #27366
711        # This test detects unexpected block reallocation. To keep the test reliable,
712        # the stream used to allocate tensors is isolated. The allocator will not
713        # reuse free blocks that were allocated on another stream.
714        # reuse free blocks which were allocated from another stream.
715        stream_alloc = torch.cuda.Stream()
716        with torch.cuda.stream(stream_alloc):
717            base = torch.cuda.FloatTensor([10, 10])
718
719        # Record another stream on a shifted view tensor.
720        view = base[5:]
721        assert view.storage_offset() > 0
722
723        stream_record = torch.cuda.Stream()
724        with torch.cuda.stream(stream_record):
725            torch.cuda._sleep(int(50 * get_cycles_per_ms()))
726
727        view.record_stream(stream_record)
728
729        # Delete those tensors to make the block free soon.
730        data_ptr = base.data_ptr()
731        del base, view
732
733        # A new tensor should not be allocated to the block above.
734        stream_alloc.synchronize()
735
736        with torch.cuda.stream(stream_alloc):
737            try_realloc = torch.cuda.FloatTensor([10, 10])
738
739        self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
740
741    def test_noncontiguous_pinned_memory(self):
742        # See issue #3266
743        x = torch.arange(0, 10).view((2, 5))
744        self.assertEqual(x.t(), x.t().pin_memory())
745
746    def test_caching_pinned_memory(self):
747        cycles_per_ms = get_cycles_per_ms()
748
749        # check that allocations are re-used after deletion
750        t = torch.FloatTensor([1]).pin_memory()
751        ptr = t.data_ptr()
752        del t
753        t = torch.FloatTensor([1]).pin_memory()
754        self.assertEqual(t.data_ptr(), ptr, msg="allocation not reused")
755
756        # check that the allocation is not re-used if it's in-use by a copy
757        gpu_tensor = torch.cuda.FloatTensor([0])
758        torch.cuda._sleep(int(1000 * cycles_per_ms))  # delay the copy by 1s
759        gpu_tensor.copy_(t, non_blocking=True)
760        del t
761        t = torch.FloatTensor([1]).pin_memory()
762        self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon")
763        self.assertEqual(list(gpu_tensor), [1])
764
765    def test_caching_allocator_record_stream_oom(self):
766        """allocations delayed by a record_stream call should still be freed on
767        an out-of-memory in cuda_malloc_retry. see issue #19219"""
768        stream = torch.cuda.Stream()
769
770        with torch.cuda.stream(stream):
771            y = torch.zeros(40 * 1024 * 1024, device="cuda")
772
773        for _ in range(100):
774            x = torch.empty(40 * 1024 * 1024, device="cuda")
775            with torch.cuda.stream(stream):
776                y += x
777            # delays re-use of `x` until after all operations in `stream`
778            x.record_stream(stream)
779            del x
780
781        # we've made a mess by allocating up to the device capacity. free any
782        # cached blocks in case it affects future tests.
783        torch.cuda.empty_cache()
784
785    # Tests for historic illegal memory access, see #17040.
786    def test_reduction_gpu_memory_accessing(self):
787        x = torch.ones(512, 8, dtype=torch.float32, device="cuda")
788        torch.sum(x, 0)
789
790    def test_sum_fp16(self):
791        x = torch.zeros(10, device="cuda", dtype=torch.float16)
792        self.assertEqual(x.sum(), 0)
793
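        # 65504 is the largest finite float16 value, so summing that many ones still fits in fp16.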
794        x = torch.ones(65504, device="cuda", dtype=torch.float16)
795        self.assertEqual(x.sum(), 65504)
796        self.assertEqual(x.sum(dtype=torch.float32), 65504)
797
798        x = torch.ones(65536, device="cuda", dtype=torch.float16)
799        self.assertEqual(x.sum(dtype=torch.float32), 65536)
800
801        a = torch.zeros(1203611).bernoulli_(0.0005)
802        x = a.to(device="cuda", dtype=torch.float16)
803        self.assertEqual(x.sum().item(), a.sum().item())
804
805        a = torch.zeros(100, 121, 80).bernoulli_(0.0005)
806        x = a.to(device="cuda", dtype=torch.float16)
807        self.assertEqual(x.sum((0, 2)).float().cpu(), a.sum((0, 2)))
808
809    def test_mean_fp16(self):
810        x = torch.ones(65536, device="cuda", dtype=torch.float16)
811        self.assertEqual(x.mean(), 1)
812
813        x = torch.ones(65536, device="cuda", dtype=torch.float16)
814        self.assertEqual(x.mean(dtype=torch.float32), 1)
815
816    def test_prod_large(self):
817        # tests global reduction (should_global_reduce = true) in case of non-zero identity element
818        x = torch.ones(240000, device="cuda", dtype=torch.float32)
819        self.assertEqual(x.prod(), 1)
820
821        # test for complex types. Note 240k is divisible by 4
822        for dtype in [torch.cfloat, torch.cdouble]:
823            x = torch.ones(240000, device="cuda", dtype=dtype) * (0 + 1j)
824            self.assertEqual(x.prod(), 1)
825
826    def test_multinomial_ext(self):
827        # Test two corner cases from older PyTorch (Issue #4858)
828        freqs = torch.cuda.FloatTensor(
829            [
830                0.0,
831                0.0,
832                0.0,
833                0.0,
834                0.0,
835                0.0,
836                0.0,
837                0.0,
838                0.0,
839                0.03178183361887932,
840                0.027680952101945877,
841                0.033176131546497345,
842                0.046052902936935425,
843                0.07742464542388916,
844                0.11543981730937958,
845                0.14148041605949402,
846                0.15784293413162231,
847                0.13180233538150787,
848                0.08271478116512299,
849                0.049702685326337814,
850                0.027557924389839172,
851                0.018125897273421288,
852                0.011851548217236996,
853                0.010252203792333603,
854                0.007422595750540495,
855                0.005372154992073774,
856                0.0045109698548913,
857                0.0036087757907807827,
858                0.0035267581697553396,
859                0.0018864056328311563,
860                0.0024605290964245796,
861                0.0022964938543736935,
862                0.0018453967059031129,
863                0.0010662291897460818,
864                0.0009842115687206388,
865                0.00045109697384759784,
866                0.0007791675161570311,
867                0.00020504408166743815,
868                0.00020504408166743815,
869                0.00020504408166743815,
870                0.00012302644609007984,
871                0.0,
872                0.00012302644609007984,
873                4.100881778867915e-05,
874                0.0,
875                0.0,
876                0.0,
877                0.0,
878                0.0,
879                0.0,
880            ]
881        )
882
883        torch.cuda.manual_seed(11042)
884        sample = torch.multinomial(freqs, 1000, True)
885        self.assertNotEqual(freqs[sample].min(), 0)
886
887        p = torch.zeros(3421, 2, device="cuda", dtype=torch.float)
888        p[:, 1] = 1
889        torch.cuda.manual_seed(5214)
890        r = torch.multinomial(p, 1)
891        self.assertNotEqual(r.min().item(), 0)
892
893        # test corner case from Issue #13867
894        torch.cuda.manual_seed(33)
895        probs = torch.randn(1000000, device="cuda").clamp(min=0) * 3e-5
896        samples = probs.multinomial(1000000, replacement=True)
897        self.assertGreater(probs[samples].min().item(), 0)
898
899    def _spawn_test_multinomial_invalid_probs_cuda(self, probs):
900        import subprocess
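        # A device-side assert leaves the CUDA context unusable, so the invalid-probability
        # call runs in a separate interpreter and only its output is inspected here.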
901
902        try:
903            p = subprocess.Popen(
904                [
905                    sys.executable,
906                    "-c",
907                    f"""\
908import sys
909import torch
910from torch import inf, nan
911try:
912    with torch.random.fork_rng(devices=[0]):
913        torch.multinomial(torch.tensor({probs}).to('cuda'), 2, replacement=True)
914        torch.cuda.synchronize()
915    sys.exit(-1) # Should not be reached
916except RuntimeError as e:
917    sys.exit(-2)
918""",
919                ],
920                stdout=subprocess.PIPE,
921                stderr=subprocess.PIPE,
922                universal_newlines=True,
923            )
924            out, err = p.communicate(timeout=10)
925            p.wait(timeout=10)
926        except subprocess.TimeoutExpired as e:
927            p.kill()
928            out, err = p.communicate()
929        expected_messages = [
930            "device-side assert triggered",  # CUDA
931            "Assertion",  # CUDA
932            "HSA_STATUS_ERROR_EXCEPTION",  # ROCm
933            "Device-side assertion",  # ROCm
934        ]
935        self.assertTrue(any(msg in out or msg in err for msg in expected_messages))
936
937    @slowTest
938    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
939    @unittest.skipIf(
940        NO_MULTIPROCESSING_SPAWN,
941        "Disabled for environments that \
942                     don't support multiprocessing with spawn start method",
943    )
944    def test_multinomial_invalid_probs_cuda(self):
945        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -1.0, 1.0])
946        self._spawn_test_multinomial_invalid_probs_cuda([1.0, inf, 1.0])
947        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -inf, 1.0])
948        self._spawn_test_multinomial_invalid_probs_cuda([1.0, 1.0, nan])
949
950    @staticmethod
951    def _mute_init():
952        os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())
953
954    def _spawn_method(self, method, arg):
955        ctx = torch.multiprocessing.get_context("spawn")
956        with ctx.Pool(1, initializer=self._mute_init) as pool:
957            errors = pool.map(method, [arg])
958            for e in errors:
959                if "device-side assert triggered" not in str(e):
960                    self.fail(e)
961
962    @staticmethod
963    def _test_index_bounds_cuda(idx):
964        x = torch.arange(10, device="cuda")
965        try:
966            y = x[torch.tensor([idx])]
967            return f"x[torch.tensor([{idx})]={y}"
968        except RuntimeError as err:
969            return err
970
971    @slowTest
972    @unittest.skipIf(
973        NO_MULTIPROCESSING_SPAWN,
974        "Disabled for environments that \
975                     don't support multiprocessing with spawn start method",
976    )
977    @skipIfRocm
978    def test_index_out_of_bounds_exception_cuda(self):
979        test_method = TestCuda._test_index_bounds_cuda
980        # Test in-bound access works fine
981        self.assertEqual(
982            test_method(1), "x[torch.tensor([1)]=tensor([1], device='cuda:0')"
983        )
984        # Test that indexing out of bounds causes assert
985        self._spawn_method(test_method, 11)
986
987    @slowTest
988    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
989    @serialTest()
990    def test_huge_index(self):
991        src = torch.empty(15000000, 45, device="cuda", dtype=torch.long).random_(
992            0, 2**22
993        )
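        # src holds 15e6 * 45 int64 values (~5.4 GB); indexing materializes a second copy,
        # hence the TEST_LARGE_TENSOR guard on this test.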
994        idx = torch.randperm(src.shape[0], device="cuda")
995        res = src[idx]
996        res_cpu = src.cpu()[idx.cpu()]
997        self.assertEqual(res.cpu(), res_cpu)
998
999    def test_randint_randomness_for_large_range(self) -> None:
1000        # For large ranges, randint generation is slightly different. This led to a subtle bug where some Philox
1001        # offsets were not calculated correctly, resulting in reused random states.
1002        # See https://github.com/pytorch/pytorch/issues/125224
1003        size = 1_000_000
1004        high = 6_000_000_000  # Keep this above 2**32
1005
1006        def run(dev: torch.device) -> int:
1007            # Measure how many unique numbers are generated in 2 consecutive calls to randint. If random states are
1008            # reused, this will yield fewer unique numbers.
1009            gen = torch.Generator(device=dev)
1010            gen.manual_seed(0)
1011            t1 = torch.randint(
1012                0, high, [size], device=dev, generator=gen, dtype=torch.int64
1013            )
1014            t2 = torch.randint(
1015                0, high, [size], device=dev, generator=gen, dtype=torch.int64
1016            )
1017            return torch.stack([t1, t2]).unique().shape[0]
1018
1019        # Use CPU as reference. The results should not deviate too much.
1020        assert abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000
1021
1022    @parametrize("dtype", [torch.float32, torch.double])
1023    def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None:
1024        # Test if random states do not overlap between consecutive rand/randn calls.
1025        # See https://github.com/pytorch/pytorch/issues/125224
1026
1027        def run(func, dev: torch.device, dtype: torch.dtype) -> int:
1028            # Measure how many unique numbers are generated in 2 consecutive calls. If random states are
1029            # reused, this will yield fewer unique numbers.
1030            size = 1000000
1031            gen = torch.Generator(device=dev)
1032            gen.manual_seed(0)
1033            t1 = func((size,), device=dev, generator=gen, dtype=dtype)
1034            t2 = func((size,), device=dev, generator=gen, dtype=dtype)
1035            return torch.stack([t1, t2]).unique().shape[0]
1036
1037        # Use CPU as reference. The results should not deviate too much.
1038        for func in [torch.rand, torch.randn]:
1039            deviation = abs(
1040                run(func, torch.device("cuda"), dtype)
1041                - run(func, torch.device("cpu"), dtype)
1042            )
1043            assert deviation < 50_000, deviation
1044
1045    def test_min_max_inits(self):
1046        # Testing if THC_reduceAll received the correct index initialization.
1047        # This affects the result of THC_reduceAll operations at extreme values
1048        x = torch.cuda.ByteTensor([0])
1049        y = torch.cuda.ByteTensor([255])
1050        expected = torch.cuda.LongTensor([0])[0]
1051
1052        _, v = x.max(dim=0)
1053        self.assertEqual(v, expected)
1054
1055        _, v = y.min(dim=0)
1056        self.assertEqual(v, expected)
1057
1058    def test_nvtx(self):
1059        # Just making sure we can see the symbols
1060        torch.cuda.nvtx.range_push("foo")
1061        torch.cuda.nvtx.mark("bar")
1062        torch.cuda.nvtx.range_pop()
1063        range_handle = torch.cuda.nvtx.range_start("range_start")
1064        torch.cuda.nvtx.range_end(range_handle)
1065
1066    def test_bincount_ext(self):
1067        # ensure CUDA code coverage
1068        input_size = (100000,)
1069        w = torch.randn(input_size, dtype=torch.double, device="cuda")
1070        w_cpu = w.cpu()
1071        # test shared memory impl
1072        t = torch.randint(50, input_size, dtype=torch.int8, device="cuda")
1073        self.assertEqual(t.cpu().bincount(), t.bincount())
1074        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
1075        # test global memory impl
1076        #   see `CUDAHistogramMemoryType` in SummaryOps.cu
1077        #   50000 * sizeof(int64_t) == 390 KiB, which should exceed smem of any known GPU
1078        t = torch.randint(50000, input_size, dtype=torch.int64, device="cuda")
1079        self.assertEqual(t.cpu().bincount(), t.bincount())
1080        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
1081
1082        t = torch.zeros([10], dtype=torch.int32, device="cuda")
1083        # 35488 * 65536 as int32 would overflow to a negative value,
1084        # giving a negative bin offset
1085        t[0] = 35488
1086        counted = t.bincount(minlength=65536)
1087        self.assertEqual(torch.sum(counted), 10)
1088
1089    def test_tiny_half_norm_(self):
1090        a = torch.arange(25).cuda().float()
1091        a /= 100000000
1092        b = a.half()
1093        self.assertGreater(b.norm().item(), 0)
1094
1095    def test_norm_type_conversion(self):
1096        a = torch.ones(65536).cuda().half()
1097        self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)
1098
1099    def test_cuda_memory_leak_detection_propagates_errors(self):
1100        with self.assertRaisesRegex(
1101            RuntimeError, r"The size of tensor a \(3\) must match"
1102        ):
1103            with self.assertLeaksNoCudaTensors():
1104                x = torch.randn(3, 1, device="cuda")
1105                y = torch.randn(2, 1, device="cuda")
1106                z = x + y
1107
1108    @unittest.skipIf(not TEST_MEDIUM_TENSOR, "not enough memory")
1109    @serialTest()
1110    def test_cuda_kernel_loop_overflow(self):
1111        # Issue #24309: In extreme cases, the loop variable could overflow and continue
1112        # the kernel loop with a negative index, causing a RuntimeError (invalid write):
1113        x = torch.randn(1, 1, 1, 2**30 + 1, dtype=torch.float16, device="cuda")
1114        expected = x[0, 0, 0, 2**30]
1115        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1116        torch.cuda.synchronize()
1117        self.assertEqual(y[0, 0, 0, 2**30], expected)
1118
1119    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
1120    @gcIfJetson
1121    @serialTest()
1122    def test_cuda_kernel_loop_overflow_large(self):
1123        # Make sure input.numel() > INT_MAX is handled:
1124        x = torch.randn(1, 1, 1, 2**31, dtype=torch.float16, device="cuda")
1125        with self.assertRaisesRegex(RuntimeError, "integer out of range"):
1126            y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1127
1128        # Issue #24309: In extreme cases, the loop variable could overflow and continue
1129        # the kernel loop with a negative index, causing a RuntimeError (invalid write):
1130        x = torch.randn(1, 1, 1, 2**31 - 1, dtype=torch.float16, device="cuda")
1131        expected = x[0, 0, 0, 2**31 - 2]
1132        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1133        torch.cuda.synchronize()
1134        self.assertEqual(y[0, 0, 0, 2**31 - 2], expected)
1135
1136    # this might create a reference cycle on self...
1137    def _make_multiply_in_stream(self):
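        # Builds an autograd Function that remembers the stream used in forward and asserts
        # that backward runs on that same stream (the engine's per-op stream semantics).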
1138        class MultiplyInStream(torch.autograd.Function):
1139            @staticmethod
1140            def forward(ctx, x, val):
1141                ctx.val = val
1142                ctx.stream = torch.cuda.current_stream()
1143                return x * val
1144
1145            @staticmethod
1146            def backward(ctx, grad):
1147                self.assertEqual(torch.cuda.current_stream(), ctx.stream)
1148                # delays the operation in the background stream
1149                torch.cuda._sleep(1000 * 5000)
1150                return grad * ctx.val, None
1151
1152        return MultiplyInStream
1153
1154    @skipCUDANonDefaultStreamIf(True)
1155    def test_streaming_backwards_sync(self):
1156        default_stream = torch.cuda.current_stream()
1157        stream = torch.cuda.Stream()
1158
1159        MultiplyInStream = self._make_multiply_in_stream()
1160
1161        # Tests using grads outside the backward() stream context
1162        # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html
1163        x = torch.randn(5, 5, device="cuda", requires_grad=True)
1164        with torch.cuda.stream(stream):
1165            stream.wait_stream(default_stream)
1166            output = MultiplyInStream.apply(x, 2)
1167            output.sum().backward()
1168        # sync needed
1169        default_stream.wait_stream(stream)
1170        self.assertEqual(x.grad, torch.ones_like(x) * 2)
1171        self.assertEqual(torch.cuda.current_stream(), default_stream)
1172
1173        # Tests that using grads in the same stream context as backward()
1174        # is safe regardless of which streams the bwd ops ran on
1175        bwd_ambient_stream = torch.cuda.Stream()
1176        x = torch.randn(5, 5, device="cuda", requires_grad=True)
1177        with torch.cuda.stream(stream):
1178            stream.wait_stream(default_stream)
1179            output = MultiplyInStream.apply(x, 3)
1180        with torch.cuda.stream(bwd_ambient_stream):
1181            bwd_ambient_stream.wait_stream(stream)
1182            output.sum().backward()
1183            # x was first used on "stream" so its AccumulateGrad leaf should run on "stream".
1184            # The end of backward() should have synced "bwd_ambient_stream" with "stream"
1185            # so it should be safe to use x.grad here without any syncs.
1186            self.assertEqual(x.grad, torch.ones_like(x) * 3)
1187            self.assertEqual(torch.cuda.current_stream(), bwd_ambient_stream)
1188
1189    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
1190    @skipIfRocm(msg="flakey on ROCm https://github.com/pytorch/pytorch/issues/53190")
1191    def test_streaming_backwards_multiple_streams(self):
1192        MultiplyInStream = self._make_multiply_in_stream()
1193
1194        class StreamModel(torch.nn.Module):
1195            def __init__(self) -> None:
1196                super().__init__()
1197                self.event = torch.cuda.Event()
1198                self.stream0 = torch.cuda.Stream()
1199                self.stream1 = torch.cuda.Stream()
1200
1201            def forward(self, x, x_first_use_on_ambient):
1202                if x_first_use_on_ambient:
1203                    x0 = x.clone()
1204                self.stream0.wait_stream(torch.cuda.current_stream())
1205                self.stream1.wait_stream(torch.cuda.current_stream())
1206                with torch.cuda.stream(self.stream0):
1207                    if not x_first_use_on_ambient:
1208                        x0 = x.clone()
1209                    y0 = MultiplyInStream.apply(x0, 2)
1210                    self.event.record(stream=torch.cuda.current_stream())
1211
1212                with torch.cuda.stream(self.stream1):
1213                    y1 = MultiplyInStream.apply(x, 3)
1214                    self.stream1.wait_event(self.event)
1215                    return y0 + y1
1216
1217        stream = torch.cuda.Stream()
1218
1219        for x_first_use_on_ambient in (True, False):
1220            # the out_of_place=False, iters=1 case stresses whether proper syncs are inserted
1221            # when grads are initially None and stolen by backward ops.
1222            for out_of_place, iters in ((True, 1), (False, 1), (False, 5)):
1223                with torch.cuda.stream(stream):
1224                    x = torch.randn(5, 5, device="cuda", requires_grad=True)
1225                    model = StreamModel().cuda()
1226                    x.register_hook(
1227                        lambda grad: self.assertEqual(
1228                            torch.cuda.current_stream(),
1229                            stream if x_first_use_on_ambient else model.stream0,
1230                        )
1231                    )
1232                    for p in model.parameters():
1233                        self.assertTrue(p.grad is None)
1234                    for i in range(iters):
1235                        loss = model(x, x_first_use_on_ambient).sum()
1236                        if out_of_place:
1237                            x_grad = torch.autograd.grad((loss,), (x,))[0]
1238                        else:
1239                            loss.backward()
1240                # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html
1241                torch.cuda.current_stream().wait_stream(stream)
1242
1243                if out_of_place:
1244                    self.assertEqual(x_grad, torch.ones_like(x) * 5 * iters)
1245                else:
1246                    self.assertEqual(x.grad, torch.ones_like(x) * 5 * iters)
1247
1248    def test_streaming_backwards_sync_graph_root(self):
1249        # This function tests if bwd ops running on a side stream properly sync with the GraphRoot.
1250        # The potential bug it targets is a race condition. The test uses multiple trials and
1251        # torch.cuda._sleep such that if the race condition exists, the test will almost certainly fail,
1252        # but there's a chance it may spuriously pass. Passing does not guarantee the backend is bug-free,
1253        # but failure does guarantee there is a bug.
1254        fwd_bwd_op_stream = torch.cuda.Stream()
1255        bwd_ambient_stream = torch.cuda.Stream()
1256        # We need these streams to be different otherwise the test is meaningless.
1257        self.assertTrue(fwd_bwd_op_stream != bwd_ambient_stream)
1258
1259        size = int(1e3)
1260
1261        a = torch.full((size,), 2.0, device="cuda", requires_grad=True)
1262        b = torch.full((size,), 3.0, device="cuda", requires_grad=True)
1263
1264        # I don't think we need any manual record_streams below.
1265        # a and b remain in scope for the entire test.
1266        # c and grad remain in scope for each iteration, and there's a full sync between iterations.
1267        for trial in range(5):
1268            torch.cuda.synchronize()
1269            a.grad = b.grad = None
1270            with torch.cuda.stream(fwd_bwd_op_stream):
1271                c = a * b
1272
1273            with torch.cuda.stream(bwd_ambient_stream):
1274                torch.cuda.synchronize()
1275                # Long-running dummy kernel on bwd_ambient_stream delays filling of grad
1276                torch.cuda._sleep(int(50 * get_cycles_per_ms()))
1277                # Fills grad on bwd_ambient_stream
1278                grad = torch.full((size,), float(trial + 1), device="cuda")
1279
1280                # Bwd ops still run on fwd_bwd_op_stream, so the following will likely fail if
1281                # bwd ops don't sync with bwd_ambient_stream before consuming grad.
1282                torch.autograd.backward(tensors=c, grad_tensors=grad)
1283
1284                # See https://github.com/pytorch/pytorch/issues/47028
1285                # The assertEqual calls below run on bwd_ambient_stream, so this test may also fail
1286                # if backward() fails to sync with bwd_ambient_stream at the end.
1287                # Synchronizing here works around the issue until a proper fix can be made.
1288                torch.cuda.synchronize()
1289                with torch.no_grad():
1290                    self.assertEqual(a.grad, grad * b)
1291                    self.assertEqual(b.grad, grad * a)
1292
1293    def test_streaming_backwards_callback(self):
1294        # Tests if autograd callbacks sync properly with respect to leaf streams and
1295        # the user-facing stream surrounding backward(). If it fails, the first suspect is
1296        # the sync logic where "final_callbacks_" are called in torch/csrc/autograd/engine.cpp
1297        MultiplyInStream = self._make_multiply_in_stream()
1298
1299        size = int(1e3)
1300        a = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)
1301        b = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)
1302
1303        s0 = torch.cuda.Stream()
1304        s1 = torch.cuda.Stream()
1305        s2 = torch.cuda.Stream()
1306
1307        stash = []
1308
1309        # sets up a nontrivial structure of leaf streams
1310        s0.wait_stream(torch.cuda.current_stream())
1311        with torch.cuda.stream(s0):
1312            c = MultiplyInStream.apply(a, 2)
1313
1314        s1.wait_stream(torch.cuda.current_stream())
1315        with torch.cuda.stream(s1):
1316            d = MultiplyInStream.apply(b, 3)
1317            s1.wait_stream(s0)
1318            e = c * d
1319
1320            def clone_leaf_grads():
1321                stash.append(a.grad.clone())
1322                stash.append(b.grad.clone())
1323
1324            # Use a hook on e to install the callback
1325            e.register_hook(
1326                lambda grad: torch.autograd.Variable._execution_engine.queue_callback(
1327                    clone_leaf_grads
1328                )
1329            )
1330
1331        s2.wait_stream(s1)
1332        with torch.cuda.stream(s2):
1333            e.sum().backward()
1334            # The autograd engine should sync s2 with all leaf streams then run the callback clone_leaf_grads on s2.
1335            # If those things happened properly, checking the values of the cloned grads on s2 should be safe:
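            # With c = 2 * a and d = 3 * b (a and b are all ones), e.sum().backward() gives
            # a.grad = 2 * d = 6 and b.grad = 3 * c = 6 per element, so the stashed clones
            # should be filled with 6.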
1336            self.assertEqual(stash[0], torch.full_like(a, 6))
1337            self.assertEqual(stash[1], torch.full_like(a, 6))
1338
1339    @unittest.skipIf(
1340        TEST_WITH_ROCM,
1341        "In ROCm, kernel asserts are disabled due to performance overhead",
1342    )
1343    def test_fixed_cuda_assert_async(self):
1344        with self.assertRaisesRegex(
1345            RuntimeError, "Boolean value of Tensor with no values is ambiguous"
1346        ):
1347            torch._assert_async(torch.tensor([], device="cuda"))
1348        with self.assertRaisesRegex(
1349            RuntimeError,
1350            "Boolean value of Tensor with more than one value is ambiguous",
1351        ):
1352            torch._assert_async(torch.tensor([0, 0], device="cuda"))
1353
1354        torch._assert_async(torch.tensor(1, device="cuda"))
1355        torch._assert_async(torch.tensor(0.1, device="cuda"))
1356        torch._assert_async(torch.tensor(-0.1, device="cuda"))
1357        torch._assert_async(torch.tensor(True, device="cuda"))
1358        torch._assert_async(torch.tensor(0 + 0.1j, device="cuda"))
1359
1360        fail_stmts = [
1361            "torch._assert_async(torch.tensor(0, device='cuda'))",
1362            "torch._assert_async(torch.tensor(0.0, device='cuda'))",
1363            "torch._assert_async(torch.tensor(False, device='cuda'))",
1364            "torch._assert_async(torch.tensor(0 + 0j, device='cuda'))",
1365        ]
1366
1367        import subprocess
1368
1369        for stmt in fail_stmts:
1370            with self.subTest(stmt=stmt):
1371                r = subprocess.call(
1372                    [
1373                        sys.executable,
1374                        "-c",
1375                        f"""\
1376import torch
1377
1378{stmt}
1379torch.cuda.synchronize()
1380""",
1381                    ]
1382                )
1383                self.assertTrue(r != 0)
1384
1385    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "FAIL")
1386    def test_cublas_multiple_threads_same_device(self):
1387        # Note: these parameters should be tuned very carefully.
1388        # Too small a number makes it hard for the race condition
1389        # to occur, while too large a number sometimes causes hangs.
1390        size = 1024
1391        num_threads = 2
1392        trials = 3
1393        test_iters = 100
1394
1395        weight = torch.ones((size, size), device="cuda")
1396        results = {}
1397        barrier = threading.Barrier(num_threads)
1398
1399        def _worker(t):
1400            my_stream = torch.cuda.Stream()
1401            # Hard sync so we don't need to worry about creating and using tensors
1402            # across streams or the fact that default streams are thread-local.
1403            # Those issues are not the target of this test.
1404            torch.cuda.synchronize()
1405            # Line up threads to increase likelihood of race conditions.
1406            barrier.wait()
1407            with torch.cuda.stream(my_stream):
1408                for i in range(test_iters):
1409                    # If all threads are sharing the same cublas handle,
1410                    # the following sequence may occur:
1411                    # thread 0 calls cublasSetStream()
1412                    # thread 1 calls cublasSetStream()
1413                    # thread 0 launches its raw gemm, which it thinks is in
1414                    #          its own stream, but is actually in thread 1's stream.
1415                    # thread 0 enqueues its div_, which IS in its own stream,
1416                    #          but actually now races with its gemm.
1417                    results[t] = torch.mm(results[t], weight)
1418                    results[t].div_(float(size))
1419            torch.cuda.synchronize()
1420
1421        for _ in range(trials):
1422            for t in range(num_threads):
1423                results[t] = torch.ones((size, size), device="cuda")
1424
1425            threads = [
1426                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
1427            ]
1428
1429            for thread in threads:
1430                thread.start()
1431            for thread in threads:
1432                thread.join()
1433
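            # weight and results[t] are all ones, so each mm yields entries equal to size and
            # the following div_(size) restores all ones; absent a race, the sum stays size * size.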
1434            for t in range(num_threads):
1435                self.assertEqual(results[t].sum().item(), size * size)
1436
1437    # Test is flaky on Windows (https://github.com/pytorch/pytorch/issues/57401)
1438    @unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows (see issue 57401)")
1439    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
1440    @skipIfRocm
1441    def test_cudnn_multiple_threads_same_device(self):
1442        # This function is intended to test the lazy creation and reuse of per-thread
1443        # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp.
1444        # Failure here likely indicates something wrong with that logic.
1445        weight = torch.ones((1, 1, 2, 2), device="cuda")
1446
1447        results = {}
1448
1449        num_threads = 2
1450        trials = 3
1451        test_iters = 1000
1452        barrier = threading.Barrier(num_threads)
1453
1454        with torch.backends.cudnn.flags(enabled=True):
1455
1456            def _worker(t):
1457                my_stream = torch.cuda.Stream()
1458                # Hard sync so we don't need to worry about creating and using tensors
1459                # across streams or the fact that default streams are thread-local.
1460                # Those issues are not the target of this test.
1461                torch.cuda.synchronize()
1462                # Line up threads to increase likelihood of race conditions.
1463                barrier.wait()
1464                with torch.cuda.stream(my_stream):
1465                    for _ in range(test_iters):
1466                        # If all threads are sharing the same cudnn handle,
1467                        # the following sequence may occur:
1468                        # thread 0 calls setCuDNNStreamToCurrent()
1469                        # thread 1 calls setCuDNNStreamToCurrent()
1470                        # thread 0 launches its raw convolution, which it thinks is in
1471                        #          its own stream, but is actually in thread 1's stream.
1472                        # thread 0 enqueues its div_, which IS in its own stream,
1473                        #          but now races with its convolution.
1474                        results[t] = torch.nn.functional.conv2d(
1475                            results[t], weight, padding=0
1476                        )
1477                        results[t].div_(4.0)
1478                torch.cuda.synchronize()
1479
1480            for _ in range(trials):
1481                for t in range(num_threads):
1482                    results[t] = torch.ones((1, 1, 2048, 2048), device="cuda")
1483
1484                threads = [
1485                    threading.Thread(target=_worker, args=(t,))
1486                    for t in range(num_threads)
1487                ]
1488
1489                for thread in threads:
1490                    thread.start()
1491                for thread in threads:
1492                    thread.join()
1493
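                # An all-ones 2x2 convolution with padding=0 yields 4 per element and shrinks each
                # spatial dim by 1, and div_(4.0) restores ones, so after test_iters iterations the
                # sum should be (2048 - test_iters) ** 2 absent a race.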
1494                for t in range(num_threads):
1495                    self.assertEqual(
1496                        results[t].sum().item(),
1497                        (2048 - test_iters) * (2048 - test_iters),
1498                    )
1499
1500    def test_cusparse_multiple_threads_same_device(self):
1501        size = 1024
1502        num_threads = 2
1503        trials = 3
1504        test_iters = 500
1505
1506        def ones_sparse(size):
1507            a = torch.arange(size, device="cuda")
1508            indices = torch.cartesian_prod(a, a).t()
1509            values = torch.ones(size * size, device="cuda")
1510            return torch.sparse_coo_tensor(indices, values)
1511
1512        weight = ones_sparse(size)
1513        results = {}
1514        barrier = threading.Barrier(num_threads)
1515
1516        def _worker(t):
1517            my_stream = torch.cuda.Stream()
1518            # Hard sync so we don't need to worry about creating and using tensors
1519            # across streams or the fact that default streams are thread-local.
1520            # Those issues are not the target of this test.
1521            torch.cuda.synchronize()
1522            # Line up threads to increase likelihood of race conditions.
1523            barrier.wait()
1524            with torch.cuda.stream(my_stream):
1525                for i in range(test_iters):
1526                    # If all threads are sharing the same cusparse handle,
1527                    # the following sequence may occur:
1528                    # thread 0 calls cusparseSetStream()
1529                    # thread 1 calls cusparseSetStream()
1530                    # thread 0 launches its raw sparse mm, which it thinks is in
1531                    #          its own stream, but is actually in thread 1's stream.
1532                    # thread 0 enqueues its div_, which IS in its own stream,
1533                    #          but actually now races with its sparse mm.
1534                    results[t] = weight.mm(results[t])
1535                    results[t].div_(float(size))
1536            torch.cuda.synchronize()
1537
1538        for _ in range(trials):
1539            for t in range(num_threads):
1540                results[t] = torch.ones((size, size), device="cuda")
1541
1542            threads = [
1543                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
1544            ]
1545
1546            for thread in threads:
1547                thread.start()
1548            for thread in threads:
1549                thread.join()
1550
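            # As in the dense cublas test above: the all-ones sparse mm followed by div_(size)
            # leaves results[t] all ones, so the sum should stay size * size absent a race.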
1551            for t in range(num_threads):
1552                self.assertEqual(results[t].sum().item(), size * size)
1553
1554    @slowTest
1555    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
1556    @serialTest()
1557    def test_max_large_axis(self):
1558        x = torch.zeros(2**32, device="cuda", dtype=torch.int8)
1559        x[-1] = 1
1560        val, idx = x.max(0)
1561        self.assertEqual(val, 1)
1562        self.assertEqual(idx, x.shape[0] - 1)
1563
1564    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1565    def test_to_numpy(self):
1566        self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy())
1567
1568    def test_graph_is_current_stream_capturing(self):
1569        self.assertFalse(torch.cuda.is_current_stream_capturing())
1570
1571        if TEST_CUDA and (not TEST_WITH_ROCM):
1572            s = torch.cuda.Stream()
1573            with torch.cuda.stream(s):
1574                g = torch.cuda.CUDAGraph()
1575                self.assertFalse(torch.cuda.is_current_stream_capturing())
1576                g.capture_begin()
1577                self.assertTrue(torch.cuda.is_current_stream_capturing())
1578                g.capture_end()
1579
1580    @unittest.skipIf(
1581        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1582    )
1583    def test_graph_capture_simple(self):
1584        s = torch.cuda.Stream()
1585
1586        with torch.cuda.stream(s):
1587            a = torch.full((1000,), 1, device="cuda")
1588            g = torch.cuda.CUDAGraph()
1589            torch.cuda.empty_cache()
1590            g.capture_begin()
1591            b = a
1592            for _ in range(10):
1593                b = b + 1
1594            g.capture_end()
1595        torch.cuda.current_stream().wait_stream(s)
1596
1597        g.replay()
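        # a is 1000 ones and the captured work adds 1 ten times, so after replay each
        # element of b is 11 and the sum is 11000.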
1598
1599        self.assertTrue(b.sum().item() == 11000.0)
1600
1601    @unittest.skipIf(
1602        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1603    )
1604    def test_graphsafe_set_get_rng_state(self):
1605        # Define a function to create generator states
1606        def create_states(generator):
1607            """Initializes the generator and returns its current state plus a cloned state."""
1608            # Ensure the CUDA generator is initialized
1609            torch.rand(1, device="cuda")
1610            generator.manual_seed(0)
1611
1612            # Save the current state of the generator
1613            old_state = generator.graphsafe_get_state()
1614            # Create and save a cloned state of the generator
1615            new_state = generator.clone_state()
1616            # Return the original generator and its two states
1617            return generator, old_state, new_state
1618
1619        def register_states_to_graph(generator_state, graph):
1620            generator, old_state, new_state = generator_state
1621            graph.register_generator_state(old_state)
1622            graph.register_generator_state(new_state)
1623
1624        # Define a function to perform specific RNG actions using the generator's states
1625        def perform_random_generation_steps(generator_state):
1626            generator, old_state, new_state = generator_state
1627            random_values = []
1628
1629            # Generate random numbers with the new generator state
1630            generator.graphsafe_set_state(new_state)
1631            random_values.append(torch.rand(5, device="cuda", generator=generator))
1632
1633            # Generate random numbers twice with the old generator state
1634            generator.graphsafe_set_state(old_state)
1635            random_values.extend(
1636                [torch.rand(5, device="cuda", generator=generator) for _ in range(2)]
1637            )
1638
1639            return random_values
1640
1641        # Define a function to retrieve the final offsets of the original and new generator states
1642        def get_final_offsets_of_states(generator_state):
1643            generator, old_state, new_state = generator_state
1644            old_state_offset = old_state.get_offset()
1645            new_state_offset = new_state.get_offset()
1646            return old_state_offset, new_state_offset
1647
1648        # Set up and test a new CUDA generator
1649        generator = torch.Generator(device="cuda")
1650        generator_state = create_states(generator)
1651
1652        # Set up and test the default CUDA generator with a CUDA Graph
1653        g = torch.cuda.CUDAGraph()
1654        s = torch.cuda.Stream()
1655        default_generator = torch.cuda.default_generators[0]
1656        default_generator_state = create_states(default_generator)
1657        register_states_to_graph(default_generator_state, g)
1658
1659        # Perform random number generation within a CUDA graph
1660        with torch.cuda.stream(s):
1661            g.capture_begin()
1662            graphed_random_values = perform_random_generation_steps(
1663                default_generator_state
1664            )
1665            g.capture_end()
1666
1667        # Synchronize the streams and replay the graph
1668        torch.cuda.current_stream().wait_stream(s)
1669        for _ in range(3):
1670            random_values = perform_random_generation_steps(generator_state)
1671            g.replay()
1672            offset = get_final_offsets_of_states(generator_state)
1673            graph_offset = get_final_offsets_of_states(default_generator_state)
1674
1675            # Compare the final offsets of states for both generators to ensure consistency
1676            self.assertTrue(offset == graph_offset)
1677            # Compare the states generated outside and inside the graph
1678            self.assertEqual(random_values, graphed_random_values)
1679
1680    @unittest.skipIf(
1681        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1682    )
1683    def test_memory_stats_of_multiple_generators_and_graphs(self):
1684        # Function to clear CUDA cache and collect garbage
1685        def clear_cuda_cache():
1686            gc.collect()
1687            torch.cuda.empty_cache()
1688
1689        # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph.
1690        def simple_graph_task(graph):
1691            s = torch.cuda.Stream()
1692            with torch.cuda.stream(s):
1693                graph.capture_begin()
1694                torch.rand(1, device="cuda")
1695                graph.capture_end()
1696            torch.cuda.current_stream().wait_stream(s)
1697            graph.replay()  # Replays the captured operations
1698
1699        def get_memory_stats():
1700            stats = torch.cuda.memory_stats()
1701            num_blocks = stats["active.all.current"]
1702            total_size = stats["active_bytes.all.current"]
1703            return num_blocks, total_size
1704
1705        def test(num_graphs, num_generators):
1706            baseline = get_memory_stats()
1707            baseline_num_blocks, baseline_total_size = baseline
1708
1709            # Allocate CUDA graphs
1710            graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)]
1711
1712            # Allocate and manage generator states
1713            default_generator = torch.cuda.default_generators[0]
1714            generators = [default_generator.graphsafe_get_state()]
1715
1716            # Starts from 1 as one state is already added
1717            for _ in range(1, num_generators):
1718                generators.append(default_generator.clone_state())
1719
1720            for graph in graphs:
1721                for generator_state in generators:
1722                    graph.register_generator_state(generator_state)
1723                simple_graph_task(graph)
1724
1725            # Assert conditions after graph tasks
1726            num_blocks, total_size = get_memory_stats()
1727            # The allocated blocks should only be proportional to the number of generators
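            # Each registered generator state keeps two one-element device tensors (the RNG seed
            # and offset holders), each occupying its own 512-byte block in the small pool, hence
            # 2 blocks and 1024 bytes per generator.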
1728            expected_blocks_diff = 2 * num_generators
1729            expected_size_diff = 2 * 512 * num_generators  # Each block's size is 512
1730
1731            self.assertTrue(
1732                (num_blocks - baseline_num_blocks) == expected_blocks_diff,
1733                "Unexpected number of active blocks.",
1734            )
1735            self.assertTrue(
1736                (total_size - baseline_total_size) == expected_size_diff,
1737                "Unexpected total memory size.",
1738            )
1739
1740            # Cleanup graphs and clear CUDA cache
1741            while graphs:
1742                graph = graphs.pop()
1743                del graph
1744            clear_cuda_cache()
1745
1746            # Assert that memory stats return to baseline after cleanup
1747            self.assertTrue(
1748                get_memory_stats() == baseline,
1749                "Memory stats do not match baseline after cleanup.",
1750            )
1751
1752        # Running the test function with different parameters
1753        test(1, 1)
1754        test(3, 2)
1755        test(10, 20)
1756
1757    @unittest.skipIf(
1758        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1759    )
1760    def test_graph_capture_reset_recapture(self):
1761        s = torch.cuda.Stream()
1762
1763        with torch.cuda.stream(s):
1764            a = torch.full((1000,), 1, device="cuda")
1765            g = torch.cuda.CUDAGraph()
1766            torch.cuda.empty_cache()
1767            g.capture_begin()
1768            b = a
1769            for _ in range(10):
1770                b = b + 1
1771            g.capture_end()
1772        torch.cuda.current_stream().wait_stream(s)
1773
1774        g.replay()
1775
1776        self.assertTrue(b.sum().item() == 11000.0)
1777
1778        g.reset()
1779
1780        with torch.cuda.stream(s):
1781            g.capture_begin()
1782            b.fill_(2.0)
1783            for _ in range(10):
1784                b = b + 2
1785            g.capture_end()
1786        torch.cuda.current_stream().wait_stream(s)
1787
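        # The recaptured work fills b with 2 and adds 2 ten times, so after replay each
        # element is 22 and the sum is 22000.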
1788        g.replay()
1789        self.assertTrue(b.sum().item() == 22000.0)
1790
1791        g.reset()
1792        del g
1793
1794    @unittest.skipIf(
1795        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1796    )
1797    def test_graph_debugdump(self):
1798        torch.cuda.empty_cache()
1799        x = torch.randn(10240000, device="cuda")
1800        y = torch.rand_like(x)
1801        g = torch.cuda.CUDAGraph()
1802        g.enable_debug_mode()
1803        s0 = torch.cuda.Stream()
1804        s1 = torch.cuda.Stream()
1805        s0.wait_stream(torch.cuda.current_stream())
1806        with torch.cuda.stream(s0):
1807            g.capture_begin()
1808            z = x + y
1809            with torch.cuda.stream(s1):
1810                s1.wait_stream(s0)
1811                w = z + y
1812            s0.wait_stream(s1)
1813            g.capture_end()
1814        s0.synchronize()
1815        torch.cuda.synchronize()
1816        with tempfile.TemporaryDirectory() as tempdir:
1817            g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot"))
1818
1819    @unittest.skipIf(
1820        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1821    )
1822    def test_graph_error(self):
1823        # We need to run this test in a separate thread as the error we trigger
1824        # puts the cuda context in a bad state
1825        script = """
1826import torch
1827
1828g = torch.cuda.CUDAGraph()
1829try:
1830    g.capture_begin()
1831except RuntimeError as e:
1832    if "CUDA graphs must be captured on a non-default stream." in str(e):
1833        exit(0)
1834    else:
1835        exit(1)
1836exit(2)
1837"""
1838        try:
1839            a = subprocess.check_output(
1840                [sys.executable, "-c", script],
1841                stderr=subprocess.STDOUT,
1842                # On Windows, opening the subprocess with the default CWD makes `import torch`
1843                # fail, so just set CWD to this script's directory
1844                cwd=os.path.dirname(os.path.realpath(__file__)),
1845            )
1846        except subprocess.CalledProcessError as e:
1847            if e.returncode == 1:
1848                self.assertTrue(
1849                    False,
1850                    "Error raised by starting capture without a stream is not the expected one",
1851                )
1852            elif e.returncode == 2:
1853                self.assertTrue(
1854                    False,
1855                    "Error raised by starting capture without a stream was not caught",
1856                )
1857
1858    @unittest.skipIf(
1859        (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11,
1860        "CUDA >= 11.0 required for graphs",
1861    )
1862    def test_graph_warn_if_has_zero_nodes(self):
1863        with warnings.catch_warnings(record=True) as caught:
1864            g = torch.cuda.CUDAGraph()
1865            s = torch.cuda.Stream()
1866            with torch.cuda.stream(s):
1867                g.capture_begin()
1868                g.capture_end()
1869        self.assertTrue(
1870            any("The CUDA Graph is empty" in str(w.message) for w in caught)
1871        )
1872
1873    @unittest.skipIf(
1874        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1875    )
1876    @unittest.skipIf(
1877        IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
1878    )
1879    def test_graph_capture_oom(self):
1880        oom_regex = (
1881            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
1882        )
1883        with self.assertRaisesRegex(RuntimeError, oom_regex):
1884            with torch.cuda.graph(torch.cuda.CUDAGraph()):
1885                torch.zeros(2**40, device="cuda")
1886
1887    @unittest.skipIf(
1888        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1889    )
1890    @serialTest()
1891    def test_repeat_graph_capture_cublas_workspace_memory(self):
1892        (x, y, z) = 1024, 512, 64
1893        a = torch.rand((x, y), device="cuda")
1894        b = torch.rand((y, z), device="cuda")
1895
1896        # warmup
1897        torch.mm(a, b)
1898
1899        free_bytes_before, total_bytes = torch.cuda.mem_get_info()
1900        used_gb_before = (total_bytes - free_bytes_before) / 1e9
1901
1902        for i in range(100):
1903            torch_graph = torch.cuda.CUDAGraph()
1904            with torch.cuda.graph(torch_graph):
1905                torch.mm(a, b)
1906            torch_graph.replay()
1907
1908        free_bytes_after, _ = torch.cuda.mem_get_info()
1909        used_gb_after = (total_bytes - free_bytes_after) / 1e9
1910
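        # Repeated captures should reuse the cuBLAS workspace rather than leaking a fresh one per
        # graph, so 100 captures should not grow device memory use beyond the 0.1 GB tolerance.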
1911        self.assertFalse(used_gb_before + 0.1 < used_gb_after)
1912
1913    @unittest.skipIf(
1914        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1915    )
1916    def test_graph_rng_functional(self):
1917        ops_with_kwargs = (
1918            (torch.nn.functional.dropout, {"p": 0.1}),
1919            (torch.nn.functional.rrelu, {"training": True}),
1920        )
1921        size = 10000
1922
1923        def run(op, kwargs):
1924            a = torch.randn((size,), device="cuda", dtype=torch.float)
1925
1926            # Control
1927            torch.cuda.manual_seed(5)
1928            eager_out = a
1929            for _ in range(6):
1930                eager_out = op(eager_out, **kwargs)
1931
1932            graph_in = a.clone()
1933            stream = torch.cuda.Stream()
1934            stream.wait_stream(torch.cuda.current_stream())
1935            with torch.cuda.stream(stream):
1936                torch.cuda.manual_seed(5)
1937
1938                g = torch.cuda.CUDAGraph()
1939                torch.cuda.empty_cache()
1940                g.capture_begin()
1941                graph_out = graph_in
1942                for _ in range(2):
1943                    graph_out = op(graph_out, **kwargs)
1944                g.capture_end()
1945            torch.cuda.current_stream().wait_stream(stream)
1946
1947            # Runs a graphed->eager->graphed sequence of RNG ops.
1948            # replay() plays 2 invocations of the op, so the sequence has 6
1949            # invocations total, matching Control.
1950            # replay() reads from graph_in and writes to graph_out.
1951            g.replay()
1952            out = op(graph_out, **kwargs)
1953            out = op(out, **kwargs)
1954            graph_in.copy_(out)
1955            g.replay()
1956
1957            # If replay() updated RNG state correctly, graph_out
1958            # should now hold data equal to eager_out.
1959            try:
1960                self.assertEqual(eager_out, graph_out)
1961            except Exception as e:
1962                raise RuntimeError("Failed on ", op) from e
1963
1964            # Do the same operations varying seeds
1965            seeds = [6, 128, 9999]
1966
1967            for seed in seeds:
1968                torch.cuda.manual_seed(seed)
1969                graph_in.copy_(a)
1970                for _ in range(3):
1971                    g.replay()
1972
1973                # If the random seed was not updated then the graph would
1974                # generate the same output as in previous check.
1975                try:
1976                    self.assertNotEqual(eager_out, graph_out)
1977                except Exception as e:
1978                    raise RuntimeError("Failed on ", op) from e
1979
1980                # Now repeat the same operations in non-graphed mode.
1981                torch.cuda.manual_seed(seed)
1982                for _ in range(3):
1983                    eager_out.copy_(a)
1984                    eager_out = op(eager_out, **kwargs)
1985                    eager_out = op(eager_out, **kwargs)
1986
1987                # In the end, graph_out and eager_out must be equal
1988                # as they went under the same set of operations.
1989                try:
1990                    self.assertEqual(eager_out, graph_out)
1991                except Exception as e:
1992                    raise RuntimeError("Failed on ", op) from e
1993
1994            # We hold references to all tensors used across streams until this sync,
1995            # so no need to call record_stream on those tensors.
1996            torch.cuda.synchronize()
1997
1998        for op, kwargs in ops_with_kwargs:
1999            run(op, kwargs)
2000
2001    @unittest.skipIf(
2002        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2003    )
2004    def test_graph_rng_distributions(self):
2005        size = 10000
2006        input = torch.rand((size,), device="cuda", dtype=torch.float)
2007        alloc = torch.empty((size,), device="cuda", dtype=torch.float)
2008
2009        # Torch ops to test with sample args (tuple) and kwargs (dict)
2010        torch_with_args = (
2011            ("bernoulli", (input.clone(),), {}),
2012            # multinomial uses some uncapturable CUDA calls.
2013            # TODO: reenable multinomial tests if/when the implementation is capturable.
2014            # ("multinomial", (input.clone(), size, True), {}),
2015            # ("multinomial", (input.clone(), size // 2, False), {}),
2016            # TODO: reenable normal test, where std is a device
2017            # tensor, when graph test failures are fixed
2018            # ("normal", (input.clone() + 1, input.clone()), {}),
2019            ("normal", (input.clone() + 1, 1.0), {}),
2020            ("poisson", (input.clone(),), {}),
2021            ("rand", (size,), {"device": "cuda", "dtype": torch.float}),
2022            ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}),
2023            ("randn", (size,), {"device": "cuda", "dtype": torch.float}),
2024        )
2025
2026        # Tensor methods to test with sample args (tuple)
2027        tensor_with_args = (
2028            ("bernoulli_", (input.clone(),)),
2029            ("cauchy_", ()),
2030            ("exponential_", ()),
2031            ("geometric_", (0.3,)),
2032            ("log_normal_", ()),
2033            ("normal_", ()),
2034            ("random_", ()),
2035            ("uniform_", ()),
2036        )
2037
2038        def run(module, op, args, kwargs):
2039            torch.cuda.manual_seed(5)
2040
2041            # Each path runs a dummy op to increment the state a bit before creating controls.
2042            if module == "torch":
2043                dummy = getattr(torch, op)(*args, **kwargs)
2044                control1 = getattr(torch, op)(*args, **kwargs)
2045                control2 = getattr(torch, op)(*args, **kwargs)
2046            else:
2047                dummy = alloc.clone()
2048                control1 = alloc.clone()
2049                control2 = alloc.clone()
2050                getattr(dummy, op)(*args)
2051                getattr(control1, op)(*args)
2052                getattr(control2, op)(*args)
2053
2054            stream = torch.cuda.Stream()
2055            stream.wait_stream(torch.cuda.current_stream())
2056            with torch.cuda.stream(stream):
2057                torch.cuda.manual_seed(5)
2058
2059                g = torch.cuda.CUDAGraph()
2060                torch.cuda.empty_cache()
2061                if module == "torch":
2062                    g.capture_begin()
2063                    t1 = getattr(torch, op)(*args, **kwargs)
2064                    t2 = getattr(torch, op)(*args, **kwargs)
2065                    g.capture_end()
2066                else:
2067                    t1 = alloc.clone()
2068                    t2 = alloc.clone()
2069                    g.capture_begin()
2070                    getattr(t1, op)(*args)
2071                    getattr(t2, op)(*args)
2072                    g.capture_end()
2073            torch.cuda.current_stream().wait_stream(stream)
2074
2075            if not TEST_CUDAMALLOCASYNC:
2076                # Makes sure values haven't been populated yet
2077                # (in other words, makes sure capture didn't actually run ops).
2078                # We can only try this with the native allocator, for which captured
2079                # addresses are already backed by cudaMalloced memory.
2080                # If we try it with cudaMallocAsync, CUDA won't even consider
2081                # the captured addresses allocated until replay(), and if we
2082                # access them before replay() we get IMAs.
2083                try:
2084                    self.assertNotEqual(control1, t1)
2085                    self.assertNotEqual(control2, t2)
2086                except Exception as e:
2087                    raise RuntimeError("Failed on " + module + "." + op) from e
2088
2089            # Set a new seed to check if graph would use it
2090            for seed in [6, 314, 271]:
2091                torch.cuda.manual_seed(seed)
2092                # Runs a dummy op prelude, as for controls, to make sure replay()
2093                # picks up the dummy op's state increment.
2094                if module == "torch":
2095                    dummy = getattr(torch, op)(*args, **kwargs)
2096                    control1 = getattr(torch, op)(*args, **kwargs)
2097                    control2 = getattr(torch, op)(*args, **kwargs)
2098                else:
2099                    getattr(dummy, op)(*args)
2100                    getattr(control1, op)(*args)
2101                    getattr(control2, op)(*args)
2102
2103                torch.cuda.manual_seed(seed)
2104                if module == "torch":
2105                    dummy = getattr(torch, op)(*args, **kwargs)
2106                else:
2107                    getattr(dummy, op)(*args)
2108
2109                # see above comment on TEST_CUDAMALLOCASYNC
2110                if not TEST_CUDAMALLOCASYNC:
2111                    t1.copy_(alloc)
2112                    t2.copy_(alloc)
2113
2114                # Runs RNG ops that fill t1 and t2.
2115                g.replay()
2116
2117                try:
2118                    self.assertEqual(control1, t1)
2119                    self.assertEqual(control2, t2)
2120                except Exception as e:
2121                    raise RuntimeError("Failed on " + module + "." + op) from e
2122
2123            # We hold references to all tensors used across streams until this sync,
2124            # so no need to call record_stream on those tensors.
2125            torch.cuda.synchronize()
2126
2127        for op_with_args in torch_with_args:
2128            run("torch", *op_with_args)
2129
2130        for meth_with_args in tensor_with_args:
2131            # Adds an empty dict for kwargs, which none of the Tensor methods use
2132            run("Tensor", *(meth_with_args + ({},)))
2133
2134    @unittest.skipIf(
2135        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2136    )
2137    def test_graph_two_successive(self):
2138        torch.cuda.empty_cache()
2139
2140        size = 1000
2141        kSmallBuffer = 2097152
2142
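        # func_with_temps computes 2 * (t + val) via two temporaries, so every call also
        # exercises allocations from whichever memory pool is active.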
2143        def func_with_temps(t, val):
2144            x = t.clone() + val
2145            y = t.clone() + val
2146            return x + y
2147
2148        s = torch.cuda.Stream()
2149
2150        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
2151            g0 = torch.cuda.CUDAGraph()
2152            g1 = torch.cuda.CUDAGraph()
2153
2154            a = torch.ones((size,), device="cuda")
2155
2156            s.wait_stream(torch.cuda.current_stream())
2157            with torch.cuda.stream(s):
2158                g0_args = (
2159                    (torch.cuda.graph_pool_handle(),)
2160                    if share_mem == "via graph_pool_handle()"
2161                    else ()
2162                )
2163                g0.capture_begin(*g0_args)
2164                b = a.clone()
2165                for _ in range(5):
2166                    b = func_with_temps(b, 1)
2167                g0.capture_end()
2168
2169                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
2170                g1.capture_begin(*g1_args)
2171                for _ in range(5):
2172                    b = func_with_temps(b, 1)
2173                g1.capture_end()
2174            torch.cuda.current_stream().wait_stream(s)
2175
2176            # mixes unrelated eager ops with replays
2177            c = a.clone()
2178            for _ in range(2):
2179                c = func_with_temps(c, 3)
2180            g0.replay()
2181            for _ in range(2):
2182                c = func_with_temps(c, 3)
2183            g1.replay()
2184            for _ in range(2):
2185                c = func_with_temps(c, 3)
2186
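            # b starts at 1 and the two graphs together apply b -> 2 * (b + 1) ten times:
            # 4, 10, 22, 46, 94, 190, 382, 766, 1534, 3070.
            # c starts at 1 and gets six eager applications of c -> 2 * (c + 3):
            # 8, 22, 50, 106, 218, 442.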
2187            self.assertEqual(b.sum().item(), size * 3070)
2188            self.assertEqual(c.sum().item(), size * 442)
2189
2190            if not TEST_CUDAMALLOCASYNC:
2191                # These stat checks are specific to the native allocator.
2192                if share_mem != "Don't share":
2193                    self.assertEqual(
2194                        reserved_no_sharing  # noqa: F821
2195                        - torch.cuda.memory_stats()["reserved_bytes.all.current"],
2196                        kSmallBuffer,
2197                    )
2198                else:
2199                    reserved_no_sharing = torch.cuda.memory_stats()[
2200                        "reserved_bytes.all.current"
2201                    ]
2202
2203            del a, b, c, g0, g1
2204            # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
2205            torch.cuda.synchronize()
2206            torch.cuda.empty_cache()
2207
2208    @unittest.skipIf(
2209        (not TEST_CUDA_GRAPH)
2210        or IS_WINDOWS
2211        or (  # appears to still be broken on Windows as of 11.4+
2212            torch.version.cuda
2213            and int(torch.version.cuda.split(".")[0]) == 11
2214            and int(torch.version.cuda.split(".")[1]) < 4
2215        ),
2216        "Graph bindings disallow concurrent replay for CUDA < 11.4, see "
2217        + "https://github.com/pytorch/pytorch/pull/57556",
2218    )
2219    @unittest.skipIf(
2220        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2221    )
2222    def test_graph_concurrent_replay(self):
2223        torch.cuda.empty_cache()
2224
2225        size = 1000000  # largeish to help expose race conditions
2226
2227        def func_with_temps(t, val):
2228            x = t.clone() + val
2229            y = t.clone() + val
2230            return x + y
2231
2232        s = torch.cuda.Stream()
2233
2234        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
2235            g0 = torch.cuda.CUDAGraph()
2236            g1 = torch.cuda.CUDAGraph()
2237
2238            s0 = torch.cuda.Stream()
2239            s1 = torch.cuda.Stream()
2240
2241            a = torch.ones((size,), device="cuda")
2242
2243            s.wait_stream(torch.cuda.current_stream())
2244            with torch.cuda.stream(s):
2245                g0_args = (
2246                    (torch.cuda.graph_pool_handle(),)
2247                    if share_mem == "via graph_pool_handle()"
2248                    else ()
2249                )
2250                g0.capture_begin(*g0_args)
2251                b = a.clone()
2252                for _ in range(5):
2253                    b = func_with_temps(b, 1)
2254                g0.capture_end()
2255
2256                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
2257                g1.capture_begin(*g1_args)
2258                c = a.clone()
2259                for _ in range(5):
2260                    c = func_with_temps(c, 2)
2261                g1.capture_end()
2262
2263            # To reproduce data corruption, I need g0 and g1's kernels to run concurrently.
2264            # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead.
2265            # The following pattern helps align device-side execution of g0 and g1's kernels.
2266            torch.cuda.synchronize()
2267            with torch.cuda.stream(s0):
2268                torch.cuda._sleep(1000000)
2269                s1.wait_stream(s0)
2270                g0.replay()
2271            with torch.cuda.stream(s1):
2272                g1.replay()
2273            torch.cuda.current_stream().wait_stream(s0)
2274            torch.cuda.current_stream().wait_stream(s1)
2275
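            # Uncorrupted replays would give b = 94 and c = 156 per element: five applications
            # of b -> 2 * (b + 1) and c -> 2 * (c + 2) respectively, each starting from 1.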
2276            if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"):
2277                # If we used the native allocator and shared mempools,
2278                # we expect the concurrent replays corrupted each other.
2279                self.assertNotEqual(b.sum().item(), size * 94)
2280                self.assertNotEqual(c.sum().item(), size * 156)
2281            else:
2282                # If we EITHER
2283                #   - used the native allocator without sharing mempools, OR
2284                #   - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe
2285                # we don't expect memory corruption.
2286                self.assertEqual(b.sum().item(), size * 94)
2287                self.assertEqual(c.sum().item(), size * 156)
2288
2289            del a, b, c, g0, g1
2290            # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them.
2291            torch.cuda.synchronize()
2292            torch.cuda.empty_cache()
2293
2294    @unittest.skipIf(
2295        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2296    )
2297    def test_graph_three_successive(self):
2298        torch.cuda.empty_cache()
2299
2300        size = 1000
2301
2302        s = torch.cuda.Stream()
2303
2304        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
2305            a = torch.ones((size,), device="cuda")
2306
2307            g0 = torch.cuda.CUDAGraph()
2308            g1 = torch.cuda.CUDAGraph()
2309            g2 = torch.cuda.CUDAGraph()
2310
2311            s.wait_stream(torch.cuda.current_stream())
2312            with torch.cuda.stream(s):
2313                g0_args = (
2314                    (torch.cuda.graph_pool_handle(),)
2315                    if share_mem == "via graph_pool_handle()"
2316                    else ()
2317                )
2318                g0.capture_begin(*g0_args)
2319                b = a.clone()
2320                c = b + 1
2321                d = b + 2
2322                g0.capture_end()
2323
2324                args = (g0.pool(),) if share_mem == "via pool()" else g0_args
2325
2326                g1.capture_begin(*args)
2327                e = c + 3
2328                del c
2329                g1.capture_end()
2330
2331                g2.capture_begin(*args)
2332                f = d + 4
2333                g2.capture_end()
2334            torch.cuda.current_stream().wait_stream(s)
2335
2336            # Tests that replaying in capture order is valid
2337            g0.replay()
2338            g1.replay()
2339            g2.replay()
2340
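            # b = 1, c = b + 1 = 2, d = b + 2 = 3, so e = c + 3 = 5 and f = d + 4 = 7 per element.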
2341            self.assertEqual(e.sum().item(), size * 5)
2342            self.assertEqual(f.sum().item(), size * 7)
2343
2344            # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool
2345            g0.replay()
2346            g2.replay()
2347            g1.replay()
2348
2349            expect_corruption = (not TEST_CUDAMALLOCASYNC) and (
2350                share_mem != "Don't share"
2351            )
2352            # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f.
2353            # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
2354            self.assertEqual(
2355                e.sum().item(), size * (7 + 3) if expect_corruption else size * 5
2356            )
2357            self.assertEqual(f.sum().item(), size * 7)
2358
2359            del a, b, d, e, f, g0, g1, g2
2360            # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them.
2361            torch.cuda.synchronize()
2362            torch.cuda.empty_cache()
2363
2364    @unittest.skipIf(
2365        (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC,
2366        "CUDA >= 11.0 or ROCM >= 5.3 required for graphs",
2367    )
2368    def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
2369        kSmallSize = 1048576
2370        kSmallBuffer = 2097152
2371        kLargeBuffer = 20971520
2372        kMinLargeAlloc = 10485760
2373        kRoundLarge = 2097152
2374
2375        elem = 4
2376
2377        # this was annoying to write but stresses the expectations pretty rigorously
2378        cases = (
2379            (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"),
2380            (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"),
2381            ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"),
2382            (
2383                (kMinLargeAlloc - 512) // elem,
2384                2,
2385                2 * kLargeBuffer,
2386                kLargeBuffer,
2387                "large_pool",
2388            ),
2389            (
2390                (kMinLargeAlloc + 512) // elem,
2391                3,
2392                3
2393                * (
2394                    kRoundLarge
2395                    * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge)
2396                ),
2397                kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge),
2398                "large_pool",
2399            ),
2400        )
2401
2402        stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.")
2403
2404        gc.collect()
2405        torch.cuda.empty_cache()
2406
2407        s = torch.cuda.Stream()
2408
2409        for (
2410            numel,
2411            delta_cudaMallocs,
2412            delta_cudaMalloc_bytes,
2413            delta_cudaMalloc_bytes_post_del_g,
2414            pool_string,
2415        ) in cases:
2416            if pool_string == "small_pool":
2417                delta_active_blocks = 3  # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders
2418                delta_active_bytes = (
2419                    numel * elem + 1024
2420                )  # + 1024 for CUDAGraph's rng seed and offset holders each
2421            else:
2422                delta_active_blocks = 1  # We only check the large pool, which isn't affected by rng offset holder
2423                delta_active_bytes = numel * elem
2424
2425            g = torch.cuda.CUDAGraph()
2426            s.wait_stream(torch.cuda.current_stream())
2427            with torch.cuda.stream(s):
2428                # Allocation stat estimates assume input is created on the same stream as capture_begin()
2429                # (in other words, the same stream silo as the rng offset holder, which is not allocated from the
2430                # capture's private pool).
2431                a = torch.ones((numel,), device="cuda")
2432
2433                precapture_stats = torch.cuda.memory_stats()
2434
2435                g.capture_begin()
2436                b = a.clone()
2437                for _ in range(5):
2438                    b = b.clone() + 1
2439                g.capture_end()
2440            torch.cuda.current_stream().wait_stream(s)
2441
2442            gc.collect()
2443
2444            postcapture_stats = torch.cuda.memory_stats()
2445
2446            expecteds = (
2447                delta_cudaMallocs,
2448                delta_cudaMalloc_bytes,
2449                delta_active_blocks,
2450                delta_active_bytes,
2451            )
2452            # Double checks replay and stats before and after a call to empty_cache
2453            for i in range(2):
2454                for stat, expected in zip(stats_to_check, expecteds):
2455                    stat = stat + pool_string + ".current"
2456                    current = postcapture_stats[stat] - precapture_stats[stat]
2457
2458                    # There will only ever be one expandable segment in each of the small and large pools. The way the
2459                    # bookkeeping is done in the allocator means that we never increment the number of segments.
2460                    if self.expandable_segments and "segment" in stat:
2461                        expected = 0
2462                    # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
2463                    # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
2464                    # smaller than the page size
2465                    if (
2466                        self.expandable_segments
2467                        and "reserved" in stat
2468                        and (numel == cases[3][0] or numel == cases[4][0])
2469                    ):
2470                        expected = 2 * kLargeBuffer
2471
2472                    self.assertEqual(
2473                        current,
2474                        expected,
2475                        "Pre to post capture delta of "
2476                        + stat
2477                        + f" = {current}, expected = {expected}, numel = {numel}",
2478                    )
2479
2480                g.replay()
2481                self.assertEqual(b.sum().item(), 6 * numel)
2482                if i == 0:
2483                    torch.cuda.empty_cache()
2484
2485            del g
2486            gc.collect()
2487            torch.cuda.empty_cache()
2488            postdel_stats = torch.cuda.memory_stats()
2489
2490            # Uses graph result b after graph has been deleted
2491            self.assertEqual(b.sum().item(), 6 * numel)
2492
2493            # b should be the only live reference remaining from the graph's private pool
2494            expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem)
2495            for stat, expected in zip(stats_to_check, expecteds):
2496                stat = stat + pool_string + ".current"
2497                current = postdel_stats[stat] - precapture_stats[stat]
2498
2499                # There will only ever be one expandable segment in each of the small and large pools. The way the
2500                # bookkeeping is done in the allocator means that we never increment the number of segments.
2501                if self.expandable_segments and "segment" in stat:
2502                    expected = 0
2503                # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
2504                # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
2505                # smaller than the page size
2506                if (
2507                    self.expandable_segments
2508                    and "reserved" in stat
2509                    and numel == cases[3][0]
2510                ):
2511                    expected = 2 * kLargeBuffer
2512                if (
2513                    self.expandable_segments
2514                    and "reserved" in stat
2515                    and numel == cases[4][0]
2516                ):
2517                    expected = kLargeBuffer
2518
2519                self.assertEqual(
2520                    current,
2521                    expected,
2522                    "Pre capture to post graph delete delta of "
2523                    + stat
2524                    + f" = {current}, expected = {expected}, numel = {numel}",
2525                )
2526
2527            # del a, b before the next case is essential, otherwise overwriting a and b in the next case
2528            # can throw off its allocation/deallocation counts.
2529            del a, b
2530            # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
2531            torch.cuda.synchronize()
2532            torch.cuda.empty_cache()
2533
2534    @unittest.skipIf(
2535        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2536    )
2537    def test_graph_record_stream(self):
2538        # Makes sure graph capture defers attempting to reclaim allocations used across streams. See
2539        # "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp
2540        torch.cuda.empty_cache()
2541
2542        potential_problem = torch.zeros((3,), device="cuda")
2543        a = torch.zeros((3,), device="cuda")
2544        s0 = torch.cuda.Stream()
2545        s1 = torch.cuda.Stream()
2546        s2 = torch.cuda.Stream()
2547        g = torch.cuda.CUDAGraph()
2548
2549        torch.cuda.synchronize()
2550        with torch.cuda.stream(s0):
2551            potential_problem.record_stream(s0)
2552            torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES)
2553            potential_problem.fill_(1.0)
2554        del potential_problem
2555
2556        with torch.cuda.stream(s1):
2557            g.capture_begin()
2558            # potential_problem's allocation should still be outstanding. If DeviceCachingAllocator::malloc
2559            # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life
2560            # event, which will cause the capture to error.
2561            b = a.clone()
2562
2563            # Let's also see what happens if we record_stream on a tensor during capture.
2564            s2.wait_stream(s1)
2565            with torch.cuda.stream(s2):
2566                b.fill_(1.0)
2567                b.record_stream(s2)  # dummy record_stream
2568                del b
2569            s1.wait_stream(s2)
2570            g.capture_end()
2571        torch.cuda.synchronize()
2572
2573        # A dummy allocation triggers process_events, which hopefully processes b's end-of-life event successfully.
2574        c = torch.zeros((3,), device="cuda")
2575
2576    @skipIfRocm
2577    @unittest.skipIf(
2578        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2579    )
2580    # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize
2581    # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior
2582    # as a memory leak unless we skip the leak check.
2583    @skipCUDAMemoryLeakCheckIf(True)
2584    @serialTest()
2585    def test_graph_cudnn_dropout(self):
2586        # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp.
2587        # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should
2588        # avoid syncing noncapturing streams with captured events or vice versa.
2589        torch.cuda.empty_cache()
2590
2591        model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
2592        x = torch.ones(100, 192, 512, device="cuda")
2593
2594        y = model(x)
2595
2596        g = torch.cuda.CUDAGraph()
2597        s = torch.cuda.Stream()
2598        s.wait_stream(torch.cuda.current_stream())
2599        with torch.cuda.stream(s):
2600            g.capture_begin()
2601            y = model(x)
2602            g.capture_end()
2603        torch.cuda.current_stream().wait_stream(s)
2604
2605        g.replay()
2606
2607        y = model(x)
2608
2609    @unittest.skipIf(
2610        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2611    )
2612    @parametrize(
2613        "with_amp,cache_enabled,allow_unused_input",
2614        [
2615            subtest((False, False, True), decorators=[skipIfRocm]),
2616            subtest((True, False, True), decorators=[skipIfRocm]),
2617            subtest((True, True, True), decorators=[unittest.expectedFailure]),
2618            subtest((False, False, False), decorators=[unittest.expectedFailure]),
2619        ],
2620        name_fn=lambda x, y, z: "{}{}{}".format(
2621            {True: "with_amp", False: "without_amp"}[x],
2622            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
2623            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
2624        ),
2625    )
2626    @serialTest()
2627    def test_graph_make_graphed_callables(
2628        self, with_amp, cache_enabled, allow_unused_input
2629    ):
2630        torch.manual_seed(5)
2631        torch.cuda.manual_seed(5)
2632
2633        N, D_in, H, D_out = 640, 4096, 2048, 1024
2634
2635        class MLP1(torch.nn.Module):
2636            def __init__(self, D_in: int, H: int, D_out: int):
2637                super().__init__()
2638                self.net_1 = torch.nn.Sequential(
2639                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
2640                ).cuda()
2641                self.net_2 = torch.nn.Sequential(
2642                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
2643                ).cuda()
2644
2645            def forward(self, input_dict: dict):
2646                x = input_dict["x"]
2647                return self.net_2(self.net_1(x))
2648
2649        class MLP2(torch.nn.Module):
2650            def __init__(self, D_in: int, H: int, D_out: int):
2651                super().__init__()
2652                self.net_1 = torch.nn.Sequential(
2653                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
2654                ).cuda()
2655                self.net_2 = torch.nn.Sequential(
2656                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
2657                ).cuda()
2658
2659            def forward(self, x):
2660                return self.net_2(self.net_1(x))
2661
2662        class ParameterlessModule(torch.nn.Module):
2663            def forward(self, x):
2664                idx = (
2665                    torch.arange(x.size(0), device=x.device)
2666                    .view(-1, 1)
2667                    .repeat(1, x.size(1))
2668                )
2669                return {"output": torch.gather(x, 0, idx)}
2670
2671        models = []
2672        for _ in range(2):
2673            model_section1 = MLP1(D_in, H, H).cuda()
2674            model_section2 = MLP2(H, H, D_out).cuda()
2675            model_section3 = ParameterlessModule().cuda()
2676            models.append(
2677                torch.nn.Sequential(model_section1, model_section2, model_section3)
2678            )
2679
2680        model_graphed = models[0]
2681        model_control = models[1]
2682
2683        model_graphed.load_state_dict(model_control.state_dict())
2684
2685        opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1)
2686        opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1)
2687
2688        x = torch.randn(N, D_in, device="cuda")
2689        h = torch.randn(N, H, device="cuda", requires_grad=True)
2690        h2 = torch.randn(N, D_out, device="cuda", requires_grad=True)
2691        unused_input = torch.randn(N, H, device="cuda", requires_grad=True)
2692        y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True)
2693        y = torch.randn(N, D_out, device="cuda")
2694
2695        loss_fn_control = torch.nn.functional.mse_loss
2696        relu_control = torch.nn.functional.relu
2697
2698        # This is a good stress test. It graphs five callables: three Modules and two Python functions.
2699        with torch.amp.autocast(
2700            device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
2701        ):
2702            (
2703                model_graphed[0],
2704                model_graphed[1],
2705                model_graphed[2],
2706                relu_graphed,
2707                loss_fn_graphed,
2708            ) = torch.cuda.make_graphed_callables(
2709                (
2710                    model_graphed[0],
2711                    model_graphed[1],
2712                    model_graphed[2],
2713                    relu_control,
2714                    loss_fn_control,
2715                ),
2716                (
2717                    ({"x": x, "unused_input": unused_input},),
2718                    (h,),
2719                    (h2,),
2720                    (y_pred,),
2721                    (y_pred, y),
2722                ),
2723                allow_unused_input=allow_unused_input,
2724            )
2725
2726        real_inputs = [torch.rand_like(x) for _ in range(10)]
2727        real_targets = [torch.rand_like(y) for _ in range(10)]
2728
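        # Train the graphed and control models on identical data. The graphed
        # wrappers replay captured forward/backward work but share parameters with
        # model_graphed, so both models should end up with identical weights.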
2729        for m, opt, relu, loss_fn in zip(
2730            (model_graphed, model_control),
2731            (opt_graphed, opt_control),
2732            (relu_graphed, relu_control),
2733            (loss_fn_graphed, loss_fn_control),
2734        ):
2735            # Resets RNG states before iterations for graphed and ungraphed models,
2736            # so dropout math should be bitwise identical for both.
2737            torch.manual_seed(5)
2738            torch.cuda.manual_seed(5)
2739            for data, target in zip(real_inputs, real_targets):
2740                opt.zero_grad(set_to_none=True)
2741                with torch.amp.autocast(
2742                    device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
2743                ):
2744                    y_pred = m({"x": data, "unused_input": unused_input})["output"]
2745                    y_pred = relu(y_pred)
2746                    loss = loss_fn(y_pred, target)
2747                    loss.backward()
2748                opt.step()
2749
2750        for p, pc in zip(model_graphed.parameters(), model_control.parameters()):
2751            self.assertEqual(p, pc)
2752
2753        # We graphed the models in training mode. Eval should still run ungraphed.
2754        model_graphed.eval()
2755        model_control.eval()
2756        self.assertEqual(
2757            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
2758        )
2759
2760    @unittest.skipIf(
2761        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2762    )
2763    @parametrize(
2764        "with_amp,cache_enabled,allow_unused_input",
2765        [
2766            subtest((False, False, True), decorators=[skipIfRocm]),
2767            subtest((True, False, True), decorators=[skipIfRocm]),
2768            subtest((True, True, True), decorators=[unittest.expectedFailure]),
2769            subtest((False, False, False), decorators=[skipIfRocm]),
2770        ],
2771        name_fn=lambda x, y, z: "{}{}{}".format(
2772            {True: "with_amp", False: "without_amp"}[x],
2773            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
2774            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
2775        ),
2776    )
2777    @serialTest()
2778    def test_graph_make_graphed_callables_parameterless_nograd_module(
2779        self, with_amp, cache_enabled, allow_unused_input
2780    ):
2781        torch.manual_seed(5)
2782        torch.cuda.manual_seed(5)
2783
2784        N, D_in, H, D_out = 640, 4096, 2048, 1024
2785
2786        class ParameterlessModule(torch.nn.Module):
2787            def forward(self, input_dict: dict):
2788                x = input_dict["x"]
2789                idx = (
2790                    torch.arange(x.size(0), device=x.device)
2791                    .view(-1, 1)
2792                    .repeat(1, x.size(1))
2793                )
2794                return {"output": torch.gather(x, 0, idx)}
2795
2796        models = []
2797        for _ in range(2):
2798            model_section1 = ParameterlessModule().cuda()
2799            models.append(torch.nn.Sequential(model_section1))
2800
2801        model_graphed = models[0]
2802        model_control = models[1]
2803
2804        model_graphed.load_state_dict(model_control.state_dict())
2805
2806        x = torch.randn(N, D_in, device="cuda", requires_grad=False)
2807        unused_input = torch.randn(N, H, device="cuda", requires_grad=False)
2808        y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False)
2809        y = torch.randn(N, D_in, device="cuda")
2810
2811        # Graphs a single parameterless Module whose sample inputs do not require grad.
2812        with torch.amp.autocast(
2813            device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
2814        ):
2815            model_graphed[0] = torch.cuda.make_graphed_callables(
2816                model_graphed[0],
2817                ({"x": x, "unused_input": unused_input},),
2818                allow_unused_input=allow_unused_input,
2819            )
2820
2821        real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)]
2822        real_targets = [torch.rand_like(y) for _ in range(10)]
2823
2824        for m in (model_graphed, model_control):
2825            # Resets RNG states before iterations for graphed and ungraphed models,
2826            # so any randomness should be bitwise identical for both.
2827            torch.manual_seed(5)
2828            torch.cuda.manual_seed(5)
2829            for data, _ in zip(real_inputs, real_targets):
2830                with torch.amp.autocast(
2831                    device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
2832                ):
2833                    out = m({"x": data, "unused_input": unused_input})["output"]
2834
2835        # We graphed the models in training mode. Eval should still run ungraphed.
2836        model_graphed.eval()
2837        model_control.eval()
2838        self.assertEqual(
2839            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
2840        )
2841
2842    @unittest.skipIf(
2843        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2844    )
2845    def test_graph_make_graphed_callables_same_pool(self):
2846        torch.manual_seed(5)
2847        torch.cuda.manual_seed(5)
2848        models = []
2849        num_models = 3
2850        for _ in range(num_models):
2851            models.append(
2852                torch.nn.Sequential(
2853                    torch.nn.Linear(32, 128),
2854                    torch.nn.ReLU(),
2855                    torch.nn.Linear(128, 128),
2856                ).cuda()
2857            )
2858        # we will reuse the same pool for all graph captures
2859        mempool = torch.cuda.graph_pool_handle()
2860        graphed_models = []
2861        for model in models:
2862            x = torch.randn([64, 32], device="cuda")
2863            graphed_model = deepcopy(model)
2864            graphed_model = torch.cuda.make_graphed_callables(
2865                graphed_model, (x,), pool=mempool
2866            )
2867            graphed_models.append(graphed_model)
2868
2869        for model, graphed_model in zip(models, graphed_models):
2870            x = torch.randn([64, 32], device="cuda")
2871            y = model(x)
2872            yg = graphed_model(x)
2873            l = y.norm()
2874            lg = yg.norm()
2875            l.backward()
2876            lg.backward()
2877
2878            self.assertEqual(y, yg)
2879            self.assertEqual(l, lg)
2880            for p, pg in zip(model.parameters(), graphed_model.parameters()):
2881                self.assertEqual(p, pg)
2882                self.assertEqual(p.grad, pg.grad)
2883                self.assertNotEqual(p.data_ptr(), pg.data_ptr())
2884                self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr())
2885
2886    def _test_graphed_optimizer(
2887        self, steps_warmup, steps_train, optimizer_ctor, kwargs
2888    ):
2889        for actually_do_graphs in (True, False):
2890            params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [
2891                torch.randn((), device="cuda")
2892            ]
2893            params_control = [p.clone().requires_grad_() for p in params]
2894            params_graphed = [p.clone().requires_grad_() for p in params]
2895
2896            grads = [
2897                [torch.randn_like(p) for p in params]
2898                for _ in range(steps_warmup + steps_train)
2899            ]
2900
2901            # Control (capturable=False)
2902
2903            opt = optimizer_ctor(params_control, capturable=False, **kwargs)
2904
2905            for i in range(steps_warmup + steps_train):
2906                for j, p in enumerate(params_control):
2907                    p.grad = grads[i][j]
2908                opt.step()
2909
2910            # capturable=True
2911
2912            opt = optimizer_ctor(params_graphed, capturable=True, **kwargs)
2913
2914            for i in range(steps_warmup):
2915                for j, p in enumerate(params_graphed):
2916                    p.grad = grads[i][j]
2917                opt.step()
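            # The eager warmup steps above let the capturable optimizer allocate its
            # per-parameter state on the GPU before anything is captured.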
2918
2919            if actually_do_graphs:
2920                g = torch.cuda.CUDAGraph()
2921                with torch.cuda.graph(g):
2922                    opt.step()
2923
2924            for i in range(steps_train):
2925                if actually_do_graphs:
2926                    for j, p in enumerate(params_graphed):
2927                        p.grad.copy_(grads[i + steps_warmup][j])
2928                    g.replay()
2929                else:
2930                    # Passing capturable=True to the constructor and running without graphs should still be
2931                    # numerically correct, even if it's not ideal for performance.
2932                    for j, p in enumerate(params_graphed):
2933                        p.grad = grads[i + steps_warmup][j]
2934                    opt.step()
2935
2936            for p_control, p_graphed in zip(params_control, params_graphed):
2937                self.assertEqual(p_control, p_graphed)
2938
2939    @unittest.skipIf(
2940        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2941    )
2942    def test_graph_optims_with_explicitly_capturable_param_groups(self):
2943        # Loosely mimics `_test_graphed_optimizer`, but passes two param_groups to optimizer.__init__
2944        n_warmup, n_replay = 3, 2
2945        for optimizer, second_param_group_capturable in product(
2946            (
2947                torch.optim.Adam,
2948                torch.optim.AdamW,
2949                torch.optim.ASGD,
2950                torch.optim.Adamax,
2951                torch.optim.NAdam,
2952                torch.optim.RAdam,
2953                torch.optim.Adadelta,
2954                torch.optim.RMSprop,
2955                torch.optim.Rprop,
2956            ),
2957            (True, False),
2958        ):
2959            ref_p1, param1 = (
2960                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
2961            )
2962            ref_p2, param2 = (
2963                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
2964            )
2965            grads1, grads2 = (
2966                [torch.randn_like(param1) for _ in range(n_warmup + n_replay)]
2967                for _ in range(2)
2968            )
2969            ref_grads1, ref_grads2 = (
2970                [t.clone() for t in tensors] for tensors in (grads1, grads2)
2971            )
2972            params = [
2973                {"params": [param1], "capturable": True},
2974                {"params": [param2], "capturable": second_param_group_capturable},
2975            ]
2976            opt = optimizer(params)
2977            opt_ = optimizer(
2978                [
2979                    {"params": [ref_p1], "capturable": False},
2980                    {"params": [ref_p2], "capturable": False},
2981                ]
2982            )
2983
2984            for i in range(n_warmup + n_replay):
2985                ref_p1.grad = ref_grads1[i]
2986                ref_p2.grad = ref_grads2[i]
2987                opt_.step()
2988
2989            for i in range(n_warmup):
2990                param1.grad = grads1[i]
2991                param2.grad = grads2[i]
2992                opt.step()
2993
2994            g = torch.cuda.CUDAGraph()
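            # Capturing opt.step() should only succeed when every param group is
            # capturable; with a capturable=False group the capture is expected to raise.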
2995            if not second_param_group_capturable:
2996                with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"):
2997                    with torch.cuda.graph(g):
2998                        opt.step()
2999            else:
3000                with torch.cuda.graph(g):
3001                    opt.step()
3002
3003                for i in range(n_replay):
3004                    param1.grad.copy_(grads1[n_warmup + i])
3005                    param2.grad.copy_(grads2[n_warmup + i])
3006                    g.replay()
3007                self.assertEqual(ref_p1, param1)
3008                self.assertEqual(ref_p2, param2)
3009
3010    @unittest.skipIf(
3011        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
3012    )
3013    def test_cuda_graph_error_options(self):
3014        def fn():
3015            x = torch.zeros([2000], device="cuda")
3016            y = x + x + x
3017            return y
3018
3019        mem = None
3020
3021        def raw_malloc():
3022            nonlocal mem
3023            mem = None
3024            stream = torch.cuda.Stream()
3025            try:
3026                with torch.cuda.stream(stream):
3027                    mem = torch.cuda.caching_allocator_alloc(1024)
3028            except BaseException:
3029                if mem is None:
3030                    return
3031            try:
3032                torch.cuda.caching_allocator_delete(mem)
3033                mem = None
3034                return None
3035            except BaseException:
3036                pass
3037
3038        def throws_on_cuda_event(capture_error_mode):
3039            graph = torch.cuda.CUDAGraph()
3040            torch.cuda.synchronize()
3041            stream = torch.cuda.Stream()
3042            stream.wait_stream(torch.cuda.current_stream())
3043            with torch.cuda.stream(stream):
3044                fn()
3045            stream.synchronize()
3046            torch.cuda.current_stream().wait_stream(stream)
3047            torch.cuda.synchronize()
3048            try:
3049                with torch.cuda.graph(
3050                    graph, stream=stream, capture_error_mode=capture_error_mode
3051                ):
3052                    out = fn()
3053                    thread = threading.Thread(target=raw_malloc)
3054                    thread.start()
3055                    thread.join()
3056            except Exception:
3057                if mem is not None:
3058                    torch.cuda.caching_allocator_delete(mem)
3059                return True
3060
3061            return False
3062
3063        self.assertFalse(throws_on_cuda_event("thread_local"))
3064        self.assertFalse(throws_on_cuda_event("relaxed"))
3065
3066        # An exception here would corrupt the process and make other tests fail
3067        # self.assertTrue(throws_on_cuda_event("global"))
3068
3069    @unittest.skipIf(
3070        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
3071    )
3072    def test_cuda_graph_allocator_propagates_stream(self):
3073        segments = torch.cuda.memory_snapshot()
3074        existing_pools = {s["segment_pool_id"] for s in segments}
3075        x = torch.randn(10240000, device="cuda")
3076        y = torch.rand_like(x)
3077        g = torch.cuda.CUDAGraph()
3078        s0 = torch.cuda.Stream()
3079        s1 = torch.cuda.Stream()
3080        s0.wait_stream(torch.cuda.current_stream())
3081        with torch.cuda.stream(s0):
3082            g.capture_begin()
3083            z = x + y
3084        with torch.cuda.stream(s1):
3085            s1.wait_stream(s0)
3086            w = z + y
3087        s0.wait_stream(s1)
3088        with torch.cuda.stream(s0):
3089            g.capture_end()
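        # Allocations made on both s0 and s1 during capture should land in the
        # graph's private pool: exactly two new segments sharing one new pool id.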
3090        segments = torch.cuda.memory_snapshot()
3091        x = [
3092            s["segment_pool_id"]
3093            for s in segments
3094            if s["segment_pool_id"] not in existing_pools
3095        ]
3096        self.assertEqual(len(x), 2)
3097        self.assertEqual(x[0], x[1])
3098
3099    def test_batch_norm_gather_stats(self):
3100        input = torch.randn(1, 3, 3, 3, device="cuda")
3101        mean, invstd = torch.batch_norm_gather_stats(
3102            input,
3103            mean=torch.ones(2, 3, device="cuda"),
3104            invstd=torch.ones(2, 3, device="cuda"),
3105            running_mean=None,
3106            running_var=None,
3107            momentum=0.1,
3108            eps=1e-5,
3109            count=2,
3110        )
3111        self.assertEqual(mean, torch.ones(3, device="cuda"))
3112        self.assertEqual(invstd, torch.ones(3, device="cuda"))
3113
3114    def test_matmul_memory_use(self):
3115        def get_max_used():
3116            torch.cuda.synchronize()
3117            val = torch.cuda.max_memory_allocated()
3118            torch.cuda.reset_peak_memory_stats()
3119            return val
3120
3121        a = torch.rand(1, 32, 32, device="cuda")
3122        b = torch.rand(24, 32, 1, device="cuda")
3123
3124        get_max_used()
3125
3126        torch.matmul(a, b)
3127
3128        matmul_mem = get_max_used()
3129
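        # expand() only changes strides (no new memory), and matmul on the expanded
        # view should not materialize a broadcast copy, so its peak allocation should
        # match the plain matmul above and the bmm below.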
3130        a = a.expand(24, 32, 32)
3131        torch.matmul(a, b)
3132
3133        matmul_expand_mem = get_max_used()
3134
3135        torch.bmm(a, b)
3136
3137        bmm_mem = get_max_used()
3138
3139        self.assertEqual(matmul_expand_mem, matmul_mem)
3140        self.assertEqual(bmm_mem, matmul_mem)
3141
3142    @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test")
3143    def test_rocm_backward_pass_guard(self):
3144        # The test exercises a ROCm-specific feature.
3145
3146        class MyFunction(torch.autograd.Function):
3147            @staticmethod
3148            def forward(ctx, tensor, constant):
3149                self.assertFalse(torch._C._rocm_is_backward_pass())
3150                ctx.constant = constant
3151                return tensor * constant
3152
3153            @staticmethod
3154            def backward(ctx, grad_output):
3155                self.assertTrue(torch._C._rocm_is_backward_pass())
3156                return grad_output * ctx.constant, None
3157
3158        class MyModule(torch.nn.Module):
3159            def __init__(self) -> None:
3160                super().__init__()
3161                self.a = torch.nn.Parameter(torch.randn(()))
3162
3163            def forward(self, x):
3164                return MyFunction.apply(x, self.a)
3165
3166        model = MyModule()
3167        criterion = torch.nn.MSELoss(reduction="sum")
3168        optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
3169
3170        x = torch.randn(5, 5)
3171        result = model(x)
3172        loss = criterion(result, x)
3173        optimizer.zero_grad()
3174        loss.backward()
3175        optimizer.step()
3176
3177    def test_matmul_device_mismatch(self):
3178        cpu = torch.rand((10, 10))
3179        cuda = cpu.cuda()
3180        with self.assertRaisesRegex(
3181            RuntimeError, "Expected all tensors to be on the same device"
3182        ):
3183            cpu @ cuda
3184        with self.assertRaisesRegex(
3185            RuntimeError, "Expected all tensors to be on the same device"
3186        ):
3187            cuda @ cpu
3188
3189        for s, m1, m2 in product((cpu, cuda), repeat=3):
3190            if s.device == m1.device == m2.device:
3191                torch.addmm(s, m1, m2)
3192            else:
3193                with self.assertRaisesRegex(
3194                    RuntimeError, "Expected all tensors to be on the same device"
3195                ):
3196                    torch.addmm(s, m1, m2)
3197
3198    @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient")
3199    def test_lazy_init(self):
3200        """Validate that no CUDA calls are made during the `import torch` call"""
3201
3202        def check_output(script: str) -> str:
3203            return (
3204                subprocess.check_output([sys.executable, "-c", script])
3205                .decode("ascii")
3206                .strip()
3207            )
3208
3209        VISIBLE_DEVICES = (
3210            "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES"
3211        )
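        # The subprocess points the visible-devices variable at a nonexistent index
        # only *after* `import torch`; a device count of 0 shows the import itself
        # made no CUDA calls that would have latched the original environment.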
3212        test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())"
3213        rc = check_output(test_script)
3214        self.assertEqual(rc, "0")
3215        if not TEST_WITH_ROCM:
3216            # Check that `cuInit` was not called during the import
3217            # by using ctypes to call cuDeviceGetCount() and expecting CUDA_ERROR_NOT_INITIALIZED == 3
3218            # See https://github.com/pytorch/pytorch/issues/116276 for more details
3219            libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll"
3220            cuda_driver_api_call = (
3221                f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))"
3222            )
3223            rc = check_output(
3224                f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})"
3225            )
3226            self.assertEqual(rc, "3")
3227
3228    @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
3229    def test_hip_device_count(self):
3230        """Validate device_count works with both CUDA/HIP visible devices"""
3231        test_script = """\
3232import torch
3233import os
3234print(f"{torch.cuda.device_count()}")
3235"""
3236        custom_envs = [
3237            {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
3238            {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"},
3239            {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"},
3240        ]
3241
3242        for env_config in custom_envs:
3243            env = os.environ.copy()
3244            for key, value in env_config.items():
3245                if value is None:
3246                    env.pop(key, None)
3247                else:
3248                    env[key] = value
3249            r = (
3250                subprocess.check_output([sys.executable, "-c", test_script], env=env)
3251                .decode("ascii")
3252                .strip()
3253            )
3254            self.assertEqual("1", r)
3255
3256    @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices")
3257    def test_device_count_not_cached_pre_init(self):
3258        visible_devices = (
3259            "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES"
3260        )
3261        test_script = f"""\
3262import torch
3263import os
3264r1 = torch.cuda.device_count()
3265os.environ['{visible_devices}'] = '0'
3266r2 = torch.cuda.device_count()
3267torch.empty(10, device='cuda')
3268print(f"{{r1}}, {{r2}}")
3269"""
3270
3271        r = (
3272            subprocess.check_output([sys.executable, "-c", test_script])
3273            .decode("ascii")
3274            .strip()
3275        )
3276
3277        x = torch.cuda.device_count()
3278        self.assertEqual(f"{x}, 1", r)
3279
3280    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
3281    def test_gds_fails_in_ci(self):
3282        if IS_WINDOWS or TEST_WITH_ROCM:
3283            error_msg = "is not supported on this platform"
3284        else:
3285            error_msg = "cuFileHandleRegister failed"
3286        with TemporaryFileName() as f:
3287            with self.assertRaisesRegex(RuntimeError, error_msg):
3288                file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
3289
3290
3291@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
3292@torch.testing._internal.common_utils.markDynamoStrictTest
3293class TestCudaMallocAsync(TestCase):
3294    @unittest.skipIf(
3295        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3296    )
3297    def test_memory_snapshot(self):
3298        try:
3299            torch.cuda.memory.empty_cache()
3300            torch.cuda.memory._record_memory_history("state", stacks="python")
3301            # make x the second block in a segment
3302            torch.rand(2 * 311, 411, device="cuda")
3303            unused = torch.rand(310, 410, device="cuda")
3304            x = torch.rand(311, 411, device="cuda")
3305
3306            # create a bunch of tensors that will all tile into the
3307            # same segment to exercise the history merging code
3308            # 512B is the minimum block size,
3309            # so we allocate all the tensors to this size to make sure
3310            # they tile evenly
3311            tensors = [torch.rand(128, device="cuda") for _ in range(1000)]
3312            while tensors:
3313                del tensors[randint(0, len(tensors) - 1)]
3314
3315            # exercise the history trimming code
3316            torch.rand(128 * 5, device="cuda")
3317
3318            ss = torch.cuda.memory._snapshot()
3319            found_it = False
3320            for seg in ss["segments"]:
3321                self.assertTrue("frames" in seg)
3322                for b in seg["blocks"]:
3323                    if b["requested_size"] == 311 * 411 * 4:
3324                        self.assertTrue("test_cuda" in b["frames"][0]["filename"])
3325                        found_it = True
3326                        self.assertEqual(x.untyped_storage().data_ptr(), b["address"])
3327            self.assertTrue(found_it)
3328
3329            if not IS_WINDOWS:
3330                with tempfile.NamedTemporaryFile() as f:
3331                    torch.cuda.memory._save_segment_usage(f.name)
3332                    with open(f.name) as f2:
3333                        self.assertTrue("test_cuda.py" in f2.read())
3334            del unused
3335            del x
3336            torch.cuda.empty_cache()
3337            ss = torch.cuda.memory._snapshot()
3338            self.assertTrue(
3339                ss["device_traces"][0][-1]["action"]
3340                in ("segment_free", "segment_unmap")
3341            )
3342
3343        finally:
3344            torch.cuda.memory._record_memory_history(None)
3345
3346    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
3347    def test_direct_traceback(self):
3348        from torch._C._profiler import gather_traceback, symbolize_tracebacks
3349
3350        c = gather_traceback(True, True, True)
3351        (r,) = symbolize_tracebacks([c])
3352        r = str(r)
3353        self.assertTrue("test_cuda.py" in r)
3354        self.assertTrue("unwind" in r)
3355
3356    @unittest.skipIf(
3357        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3358    )
3359    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3360    def test_memory_snapshot_with_cpp(self):
3361        try:
3362            torch.cuda.memory.empty_cache()
3363            torch.cuda.memory._record_memory_history("state", stacks="all")
3364            x = torch.rand(311, 411, device="cuda")
3365
3366            ss = torch.cuda.memory._snapshot()["segments"]
3367            found_it = False
3368            for seg in ss:
3369                for b in seg["blocks"]:
3370                    if b["requested_size"] == 311 * 411 * 4:
3371                        self.assertTrue("::rand" in str(b["frames"]))
3372                        found_it = True
3373            self.assertTrue(found_it)
3374
3375        finally:
3376            torch.cuda.memory._record_memory_history(None)
3377
3378    @skipIfRocm
3379    def test_memory_profiler_viz(self):
3380        with torch.profiler.profile(
3381            with_stack=True, profile_memory=True, record_shapes=True
3382        ) as prof:
3383            x = torch.rand(128, 128, device="cuda")
3384            x * x + x * x
3385        plot = profile_plot(prof)
3386        plot = json.dumps(_profile_to_snapshot(prof))
3387        self.assertTrue("test_cuda.py" in plot)
3388        self.assertTrue("test_memory_profiler_viz" in plot)
3389        self.assertTrue("category" in plot)
3390
3391    @unittest.skipIf(
3392        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3393    )
3394    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3395    def test_cycles(self):
3396        fired = False
3397
3398        def observer(html):
3399            nonlocal fired
3400            fired = True
3401            self.assertTrue("torch.Tensor" in html)
3402            self.assertTrue("test_cuda" in html)
3403            self.assertTrue("cell_contents" in html)
3404
3405        disarm = observe_tensor_cycles(observer)
3406
3407        def noop():
3408            pass
3409
3410        try:
3411
3412            def create():
3413                x = torch.empty(3, 4, device="cuda")
3414
3415                def foo(p):
3416                    if p:
3417                        return foo(not p)
3418                    else:
3419                        return x
3420
3421                return foo
3422
3423            create()
3424            gc.collect()
3425            # The callback has to run outside of the collect call,
3426            # so it doesn't actually fire until the next
3427            # method call after gc.collect
3428            noop()
3429            self.assertTrue(fired)
3430        finally:
3431            disarm()
3432
3433    @unittest.skipIf(
3434        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3435    )
3436    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3437    def test_memory_plots(self):
3438        for context, stacks in (
3439            ("all", "all" if IS_LINUX else "python"),
3440            ("all", "python"),
3441            (None, "python"),
3442        ):
3443            try:
3444                torch.cuda.memory.empty_cache()
3445                torch.cuda.memory._record_memory_history(
3446                    "all", context=context, stacks=stacks
3447                )
3448
3449                def run():
3450                    x = torch.rand(128, 128, device="cuda")
3451                    x * x + x * x
3452
3453                run()
3454                cpp = stacks == "all"
3455                record_context = context is not None
3456                ss = torch.cuda.memory._snapshot()
3457
3458                tplot = trace_plot(ss)
3459                splot = segment_plot(ss)
3460                text = json.dumps(ss)
3461
3462                self.assertTrue(record_context == ("test_memory_plots" in text))
3463                self.assertTrue(cpp == ("::rand" in text))
3464                self.assertTrue(str(128 * 128 * 4) in text)
3465
3466            finally:
3467                torch.cuda.memory._record_memory_history(None)
3468
3469    @unittest.skipIf(
3470        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3471    )
3472    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3473    def test_memory_plots_free_stack(self):
3474        for context in ["alloc", "all", "state"]:
3475            try:
3476                torch.cuda.memory.empty_cache()
3477                torch.cuda.memory._record_memory_history(context=context)
3478                x = None
3479
3480                def thealloc():
3481                    nonlocal x
3482                    x = torch.rand(3, 4, device="cuda")
3483
3484                def thefree():
3485                    nonlocal x
3486                    del x
3487
3488                thealloc()
3489                thefree()
3490                ss = json.dumps(torch.cuda.memory._snapshot())
3491                self.assertTrue(("thefree" in ss) == (context == "all"))
3492                self.assertTrue(("thealloc" in ss) == (context != "state"))
3493            finally:
3494                torch.cuda.memory._record_memory_history(None)
3495
3496    @unittest.skipIf(
3497        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3498    )
3499    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3500    def test_memory_plots_history_context(self):
3501        try:
3502            torch.cuda.memory.empty_cache()
3503            x = None
3504
3505            def should_capture1():
3506                nonlocal x
3507                x = torch.rand(4, 4, device="cuda")
3508
3509            def should_not_capture():
3510                nonlocal x
3511                x = torch.rand(3, 4, device="cuda")
3512
3513            def should_capture2():
3514                nonlocal x
3515                x = torch.rand(4, 4, device="cuda")
3516
3517            # Recording with context and python call stacks should capture the call stack.
3518            torch.cuda.memory._record_memory_history(context="all", stacks="python")
3519            should_capture1()
3520            # Recording with context=None should not capture the call stack.
3521            torch.cuda.memory._record_memory_history(context=None)
3522            should_not_capture()
3523            # Recording with context and python call stacks should capture the call stack.
3524            torch.cuda.memory._record_memory_history(context="all", stacks="python")
3525            should_capture2()
3526
3527            ss = json.dumps(torch.cuda.memory._snapshot())
3528            self.assertTrue("should_capture1" in ss)
3529            self.assertTrue("should_not_capture" not in ss)
3530            self.assertTrue("should_capture2" in ss)
3531        finally:
3532            torch.cuda.memory._record_memory_history(None)
3533
3534    @unittest.skipIf(
3535        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3536    )
3537    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3538    def test_memory_plots_free_segment_stack(self):
3539        for context in ["alloc", "all", "state"]:
3540            try:
3541                torch.cuda.memory.empty_cache()
3542                torch.cuda.memory._record_memory_history(context=context)
3543                x = torch.rand(3, 4, device="cuda")
3544                del x
3545                torch.cuda.memory.empty_cache()
3546
3547                ss = json.dumps(torch.cuda.memory._snapshot())
3548                self.assertTrue(("empty_cache" in ss) == (context == "all"))
3549            finally:
3550                torch.cuda.memory._record_memory_history(None)
3551
3552    @unittest.skipIf(
3553        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3554    )
3555    def test_memory_snapshot_script(self):
3556        try:
3557            torch.cuda.memory.empty_cache()
3558            torch.cuda.memory._record_memory_history("state", stacks="python")
3559
3560            @torch.jit.script
3561            def foo():
3562                return torch.rand(311, 411, device="cuda")
3563
3564            x = foo()
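            # The allocation happens inside a TorchScript function, so the first
            # recorded frame should name the scripted `foo` rather than this test.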
3565
3566            ss = torch.cuda.memory._snapshot()["segments"]
3567            found_it = False
3568            for seg in ss:
3569                for b in seg["blocks"]:
3570                    if b["requested_size"] == 311 * 411 * 4:
3571                        self.assertTrue(b["frames"][0]["name"] == "foo")
3572                        found_it = True
3573            self.assertTrue(found_it)
3574
3575        finally:
3576            torch.cuda.memory._record_memory_history(None)
3577
3578    def test_max_split_expandable(self):
3579        torch.cuda.memory.empty_cache()
3580        mb = 1024 * 1024
3581        _, all_memory = torch.cuda.memory.mem_get_info()
3582        total_allowed = 120 * mb
3583        fraction_allowed = total_allowed / all_memory
3584        assert int(fraction_allowed * all_memory) == total_allowed
3585        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
3586
3587        def alloc(n):
3588            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
3589
3590        torch.cuda.memory._set_allocator_settings(
3591            "expandable_segments:False,max_split_size_mb:40"
3592        )
3593        a = alloc(40)
3594        torch.cuda.memory._set_allocator_settings(
3595            "expandable_segments:True,max_split_size_mb:40"
3596        )
3597        b = alloc(40)
3598        torch.cuda.memory._set_allocator_settings(
3599            "expandable_segments:False,max_split_size_mb:40"
3600        )
3601        c = alloc(40)
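        # a, b and c already fill the 120 MB cap (3 x 40 MB), so a fourth 40 MB
        # allocation must OOM regardless of which expandable_segments setting was
        # active when each block was created.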
3602        with self.assertRaises(torch.OutOfMemoryError):
3603            alloc(40)
3604        del a, b, c
3605        # force release_cached_blocks to run with some expandable segments in the free list
3606        alloc(120)
3607
3608    def test_garbage_collect_expandable(self):
3609        torch.cuda.memory.empty_cache()
3610        mb = 1024 * 1024
3611        _, all_memory = torch.cuda.memory.mem_get_info()
3612        total_allowed = 120 * mb
3613        fraction_allowed = total_allowed / all_memory
3614        assert int(fraction_allowed * all_memory) == total_allowed
3615        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
3616
3617        def alloc(n):
3618            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
3619
3620        torch.cuda.memory._set_allocator_settings(
3621            "expandable_segments:False,garbage_collection_threshold:0.5"
3622        )
3623        a = alloc(40)
3624        torch.cuda.memory._set_allocator_settings(
3625            "expandable_segments:True,garbage_collection_threshold:0.5"
3626        )
3627        b = alloc(40)
3628        del a, b
3629        # causes GC to run. The expandable segment block will be split
3630        # so GC would not attempt to free it anyway, but this at least makes sure
3631        # expandable_segment blocks can be in the free list when this is called.
3632        alloc(80)
3633
3634    def test_allocator_settings(self):
3635        def power2_div(size, div_factor):
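            # Mirrors roundup_power2_divisions: within the power-of-two interval
            # containing `size`, round up to the next multiple of
            # (next_pow2 / 2) / div_factor.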
3636            pow2 = 1
3637            while pow2 < size:
3638                pow2 = pow2 * 2
3639            if pow2 == size:
3640                return pow2
3641            step = pow2 / 2 / div_factor
3642            ret = pow2 / 2
3643            while ret < size:
3644                ret = ret + step
3645            return ret
3646
3647        torch.cuda.memory.empty_cache()
3648        key_allocated = (
3649            "active_bytes.all.allocated"
3650            if not TEST_CUDAMALLOCASYNC
3651            else "allocated_bytes.all.current"
3652        )
3653        key_requested = "requested_bytes.all.allocated"
3654
3655        nelems = 21 * 1024 * 1024
3656        nbytes = 4 * nelems  # floats are 4 bytes
3657
3658        nelems_big = 100 * 1024 * 1024
3659        nbytes_big = 4 * nelems_big  # floats are 4 bytes
3660
3661        start_mem = torch.cuda.memory_stats()[key_allocated]
3662        torch.cuda.memory._set_allocator_settings("")
3663        x = torch.rand(nelems, device="cuda")
3664
3665        # test roundup_power2_divisions single value syntax
3666        reg_mem = torch.cuda.memory_stats()[key_allocated]
3667        start_requested = torch.cuda.memory_stats()[key_requested]
3668        torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4")
3669        y = torch.rand(nelems, device="cuda")
3670
3671        pow2_div4_mem = torch.cuda.memory_stats()[key_allocated]
3672        current_requested = torch.cuda.memory_stats()[key_requested]
3673
3674        self.assertTrue(reg_mem - start_mem == nbytes)
3675        if not TEST_CUDAMALLOCASYNC:
3676            # not supported with the cudaMallocAsync backend
3677            self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))
3678            self.assertTrue(current_requested - start_requested == nbytes)
3679
3680        torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5")
3681        torch.cuda.memory._set_allocator_settings(
3682            "garbage_collection_threshold:0.5,max_split_size_mb:40"
3683        )
3684
3685        # should have reset the power2 divisions now
3686        torch.cuda.memory.empty_cache()
3687        start_mem = torch.cuda.memory_stats()[key_allocated]
3688        z = torch.rand(nelems, device="cuda")
3689        reg_mem = torch.cuda.memory_stats()[key_allocated]
3690        self.assertTrue(reg_mem - start_mem == nbytes)
3691
3692        # roundup_power2_divisions knob array syntax
3693        torch.cuda.memory.empty_cache()
3694        torch.cuda.memory._set_allocator_settings(
3695            "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]"
3696        )
3697        start_mem = torch.cuda.memory_stats()[key_allocated]
3698        w = torch.rand(nelems, device="cuda")
3699
3700        pow2_div8_mem = torch.cuda.memory_stats()[key_allocated]
3701        if not TEST_CUDAMALLOCASYNC:
3702            # not supported with the cudaMallocAsync backend
3703            self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8))
3704
3705        torch.cuda.memory.empty_cache()
3706        start_mem = torch.cuda.memory_stats()[key_allocated]
3707        v = torch.rand(nelems_big, device="cuda")
3708
3709        pow2_div2_mem = torch.cuda.memory_stats()[key_allocated]
3710        if not TEST_CUDAMALLOCASYNC:
3711            # not supported with the cudaMallocAsync backend
3712            self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2))
3713
3714        torch.cuda.memory.empty_cache()
3715        torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True")
3716        start_mem = torch.cuda.memory_stats()[key_allocated]
3717        w = torch.rand(nelems, device="cuda")
3718        reg_mem = torch.cuda.memory_stats()[key_allocated]
3719        self.assertTrue(reg_mem - start_mem == nbytes)
3720
3721        with self.assertRaises(RuntimeError):
3722            torch.cuda.memory._set_allocator_settings("foo:1,bar:2")
3723
3724        with self.assertRaises(RuntimeError):
3725            torch.cuda.memory._set_allocator_settings(
3726                "garbage_collection_threshold:1.2"
3727            )
3728
3729        with self.assertRaises(RuntimeError):
3730            torch.cuda.memory._set_allocator_settings("max_split_size_mb:2")
3731
3732        with self.assertRaises(RuntimeError):
3733            torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none")
3734
3735        with self.assertRaises(RuntimeError):
3736            torch.cuda.memory._set_allocator_settings(
3737                "pinned_use_cuda_host_register:none"
3738            )
3739
3740        with self.assertRaises(RuntimeError):
3741            torch.cuda.memory._set_allocator_settings(
3742                "pinned_num_register_threads:none"
3743            )
3744
3745        with self.assertRaises(RuntimeError):
3746            torch.cuda.memory._set_allocator_settings(
3747                "pinned_num_register_threads:1024"
3748            )
3749
3750    @parametrize("max_split_size_mb_setting", [False, True])
3751    def test_raises_oom(self, max_split_size_mb_setting):
3752        if max_split_size_mb_setting:
3753            # CUDACachingAllocator returns early when searching for available blocks
3754            # if max_split_size_mb is not set.
3755            # Setting it exercises more of that code path.
3756            torch.cuda.memory._set_allocator_settings("max_split_size_mb:1024")
3757            torch.cuda.memory.empty_cache()
3758        with self.assertRaises(torch.cuda.OutOfMemoryError):
3759            torch.empty(1024 * 1024 * 1024 * 1024, device="cuda")
3760
3761    @unittest.skipIf(
3762        not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux"
3763    )
3764    @unittest.skipIf(
3765        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3766    )
3767    def test_cpp_memory_snapshot_pickle(self):
3768        from torch.utils.cpp_extension import load_inline
3769
3770        source = """
3771        #include <torch/csrc/cuda/memory_snapshot.h>
3772        py::object do_snapshot() {
3773            std::string data = torch::cuda::_memory_snapshot_pickled();
3774            return py::bytes(data);
3775        }
3776        void record(bool e, bool ctx) {
3777            torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx);
3778        }
3779        """
3780        m = load_inline(
3781            name="snapshot", cpp_sources=[source], functions=["do_snapshot", "record"]
3782        )
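        # `record(enabled, ctx)` toggles C++-side memory history recording; with
        # ctx=True the snapshot should also carry stack-frame context (the literal
        # 10 in the C++ source presumably caps how many entries are kept).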
3783        for ctx in (False, True):
3784            try:
3785                m.record(True, ctx)
3786
3787                @torch.jit.script
3788                def the_script_fn():
3789                    return torch.rand(311, 411, device="cuda")
3790
3791                def run():
3792                    t = the_script_fn()
3793                    return pickle.loads(m.do_snapshot())
3794
3795                mem = run()
3796                found = False
3797                for s in mem["segments"]:
3798                    for b in s["blocks"]:
3799                        if b["state"] == "active_allocated":
3800                            if b["requested_size"] == 311 * 411 * 4:
3801                                if ctx:
3802                                    frame_text = str(b["frames"])
3803                                    # C++ frame
3804                                    self.assertTrue("::rand" in frame_text)
3805                                    # script frame
3806                                    self.assertTrue("the_script_fn" in frame_text)
3807                                    # python frame
3808                                    self.assertTrue("case.py" in frame_text)
3809                                found = True
3810                last_action = mem["device_traces"][0][-1]
3811                self.assertTrue(last_action["action"] == "alloc")
3812                self.assertTrue(last_action["size"] == 311 * 411 * 4)
3813                self.assertTrue(found)
3814            finally:
3815                m.record(False, False)
3816
3817    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
3818    def test_notifies_oom(self):
3819        x = False
3820
3821        def cb(device, alloc, device_alloc, device_free):
3822            nonlocal x
3823            x = True
3824
3825        torch._C._cuda_attach_out_of_memory_observer(cb)
3826        with self.assertRaises(torch.cuda.OutOfMemoryError):
3827            torch.empty(1024 * 1024 * 1024 * 1024, device="cuda")
3828        self.assertTrue(x)
3829
3830    def test_allocator_fuzz(self):
3831        # fuzz
3832        state = random.getstate()
3833        random.seed(123)
3834        N = 10000
3835        try:
3836            mem = []
3837            total = 0
3838            c = 0
3839
3840            def alloc():
3841                nonlocal total, c
3842                b = random.randrange(2 * 1024 * 1024 // 4, 20 * 1024 * 1024 // 4)
3843                mem.append((c, torch.full((b,), c, dtype=torch.int32, device="cuda")))
3844                c += 1
3845                total += b
3846
3847            def free():
3848                nonlocal total
3849                idx = random.randrange(0, len(mem))
3850                v, x = mem.pop(idx)
3851                assert torch.all(v == x)
3852                total -= x.numel()
3853
3854            choices = [alloc, free, torch.cuda.memory.empty_cache]
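            # Keep total live memory under roughly 100 MiB, then pick an action:
            # alloc and free are equally weighted (free only when something is live),
            # with an occasional empty_cache mixed in.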
3855            for i in range(N):
3856                while total >= 1024 * 1024 * 1024 / (4 * 10):
3857                    free()
3858                (action,) = random.choices(choices, weights=[1, 1 if mem else 0, 0.1])
3859                action()
3860        finally:
3861            random.setstate(state)
3862
3863    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
3864    def test_nvml_get_handler(self):
3865        if not torch.version.hip:
3866            self.assertTrue(torch.cuda._get_pynvml_handler() is not None)
3867        else:
3868            self.assertTrue(torch.cuda._get_amdsmi_handler() is not None)
3869
3870    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
3871    def test_temperature(self):
3872        self.assertTrue(0 <= torch.cuda.temperature() <= 150)
3873
3874    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
3875    def test_power_draw(self):
3876        self.assertTrue(torch.cuda.power_draw() >= 0)
3877
3878    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
3879    def test_clock_speed(self):
3880        self.assertTrue(torch.cuda.clock_rate() >= 0)
3881
3882
3883MIN_BLOCK_SIZE = 512
3884SMALL_SIZE = 1048576
3885SMALL_BUFFER = 2097152
3886LARGE_BUFFER = 20971520
3887
3888
3889def get_cudagraph_segments(pool_id):
3890    segments = torch.cuda.memory_snapshot()
3891    return [segment for segment in segments if segment["segment_pool_id"] == pool_id]
3892
3893
3894def get_all_cudagraph_segments():
3895    segments = torch.cuda.memory_snapshot()
3896    return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]
3897
3898
3899def cudagraphify(fn, inputs, pool=None):
3900    if not TEST_CUDA_GRAPH:
3901        raise unittest.SkipTest("cuda graph test is skipped")
3902
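    # Warm up fn on a side stream first (so lazy one-time initialization happens
    # outside capture), then capture the same call into a CUDAGraph on that stream.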
3903    torch.cuda.synchronize()
3904    stream = torch.cuda.Stream()
3905    stream.wait_stream(torch.cuda.current_stream())
3906    with torch.cuda.stream(stream):
3907        fn(*inputs)
3908    stream.synchronize()
3909    torch.cuda.current_stream().wait_stream(stream)
3910    torch.cuda.synchronize()
3911
3912    graph = torch.cuda.CUDAGraph()
3913    with torch.cuda.graph(graph, stream=stream, pool=pool):
3914        static_outputs = fn(*inputs)
3915
3916    return graph, static_outputs
3917
3918
3919def int8_cuda(size):
3920    return torch.ones([size], device="cuda", dtype=torch.uint8)
3921
3922
3923def live_blocks(pool_id):
3924    blocks = 0
3926    for segment in get_cudagraph_segments(pool_id):
3927        for block in segment["blocks"]:
3928            blocks += block["state"] == "active_allocated"
3929    return blocks
3930
3931
3932def tensor_metadata(x):
3933    return {
3934        "nbytes": x.untyped_storage().nbytes(),
3935        "data_ptr": x.untyped_storage().data_ptr(),
3936        "size": x.shape,
3937        "stride": x.stride(),
3938        "dtype": x.dtype,
3939        "device": x.device,
3940        "storage_offset": x.storage_offset(),
3941    }
3942
3943
3944def reconstruct_from_tensor_metadata(metadata):
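    # Rebuild a tensor that aliases the original allocation: construct a storage
    # directly from the recorded data pointer, then view it with the saved
    # size/stride/offset. No copy is made, so the underlying memory must still be live.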
3945    s = torch._C._construct_storage_from_data_pointer(
3946        metadata["data_ptr"], metadata["device"], metadata["nbytes"]
3947    )
3948    t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"])
3949    t.set_(
3950        source=s,
3951        storage_offset=metadata["storage_offset"],
3952        size=metadata["size"],
3953        stride=metadata["stride"],
3954    )
3955    return t
3956
3957
3958@unittest.skipIf(not TEST_CUDA or TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
3959@torch.testing._internal.common_utils.markDynamoStrictTest
3960class TestBlockStateAbsorption(TestCase):
3961    @property
3962    def expandable_segments(self):
3963        return EXPANDABLE_SEGMENTS
3964
3965    def checkCheckpointedBlock(self, before_block, after_block):
3966        for field in ("size", "state"):
3967            self.assertEqual(before_block[field], after_block[field])
3968
3969    def checkCheckpointedState(self, before_segments, after_segments):
3970        # after may contain additional segments, but all of the segments in before
3971        # should be exactly equivalent to after
3972        after_ptr_to_segment = {
3973            segment["address"]: segment for segment in after_segments
3974        }
3975
3976        for before_segment in before_segments:
3977            self.assertTrue(before_segment["address"] in after_ptr_to_segment)
3978            after_segment = after_ptr_to_segment[before_segment["address"]]
3979
3980            for field in (
3981                "device",
3982                "total_size",
3983                "allocated_size",
3984                "active_size",
3985                "segment_type",
3986                "segment_pool_id",
3987            ):
3988                self.assertEqual(before_segment[field], after_segment[field])
3989
3990            self.assertEqual(
3991                len(before_segment["blocks"]), len(after_segment["blocks"])
3992            )
3993            for before_block, after_block in zip(
3994                before_segment["blocks"], after_segment["blocks"]
3995            ):
3996                self.checkCheckpointedBlock(before_block, after_block)
3997
3998    @staticmethod
3999    def setCheckpointPoolState(
4000        device, state, stale_storages_ptr, storages_deleters=None
4001    ):
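        """Translate the given tensors into storage handles and apply a previously
        captured checkpoint ``state`` to the memory pool on ``device`` via the private
        ``torch._C._cuda_setCheckpointPoolState`` API."""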
4002        stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr]
4003        storages_deleters = (
4004            []
4005            if not storages_deleters
4006            else [t.untyped_storage()._cdata for t in storages_deleters]
4007        )
4008        torch._C._cuda_setCheckpointPoolState(
4009            device, state, stale_storages_ptr, storages_deleters
4010        )
4011
4012    def checkFunction(self, fn, inputs, pool=None):
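        """Capture ``fn`` into a CUDA graph, snapshot the segments of its memory pool,
        checkpoint and restore the pool state, and verify the segments are unchanged."""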
4013        graph, outputs = cudagraphify(fn, inputs, pool=pool)
4014
4015        pool_id = graph.pool()
4016        device = outputs[0].device.index
4017
4018        segments_before_checkpoint = get_cudagraph_segments(pool_id)
4019
4020        state = torch._C._cuda_getCheckpointState(device, pool_id)
4021        self.setCheckpointPoolState(device, state, [], [])
4022
4023        self.checkCheckpointedState(
4024            segments_before_checkpoint, get_cudagraph_segments(pool_id)
4025        )
4026
4027    def setUp(self):
4028        super().setUp()
4029        self.segment_length = len(get_all_cudagraph_segments())
4030
4031    def tearDown(self):
4032        torch.cuda.synchronize()
4033        gc.collect()
4034        torch.cuda.empty_cache()
4035
4036        self.assertEqual(len(get_all_cudagraph_segments()), self.segment_length)
4037
4038        super().tearDown()
4039
4040    def test_simple(self):
4041        def foo():
4042            x = torch.zeros([SMALL_SIZE * 8], device="cuda", dtype=torch.uint8)
4043            x = x + x
4044            x1 = int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE)
4045            y = int8_cuda(SMALL_SIZE) + x1
4046            z = int8_cuda(SMALL_SIZE)
4047            return x, y, z
4048
4049        self.checkFunction(foo, [])
4050
4051    def test_allocated_in_middle_of_segment(self):
4052        def foo():
4053            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
4054            return small_buffers[5].add_(2)
4055
4056        self.checkFunction(foo, [])
4057
4058    def test_multiple_middle_allocations(self):
4059        def foo():
4060            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
4061            return small_buffers[5], small_buffers[8]
4062
4063        self.checkFunction(foo, [])
4064
4065    def test_middle_allocations_contiguous(self):
4066        def foo():
4067            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
4068            return small_buffers[5], small_buffers[6]
4069
4070        self.checkFunction(foo, [])
4071
4072    def test_additional_free_following_checkpoint(self):
4073        def foo():
4074            return (int8_cuda(MIN_BLOCK_SIZE),)
4075
4076        def foo2():
4077            return (int8_cuda(MIN_BLOCK_SIZE),)
4078
4079        graph, outputs = cudagraphify(foo, [])
4080        pool_id = graph.pool()
4081
4082        segments_before_checkpoint = get_cudagraph_segments(pool_id)
4083
4084        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)
4085
4086        graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
4087
4088        self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])
4089
4090        del outputs2
4091
4092        self.checkCheckpointedState(
4093            segments_before_checkpoint, get_cudagraph_segments(pool_id)
4094        )
4095
4096    # TODO: re-enable
4097    # def test_additional_free_error(self):
4098    #     def foo():
4099    #         return int8_cuda(MIN_BLOCK_SIZE),
4100
4101    #     def foo2():
4102    #         return int8_cuda(MIN_BLOCK_SIZE),
4103
4104    #     graph, outputs = cudagraphify(foo, [])
4105    #     pool_id = graph.pool()
4106
4107    #     segments_before_checkpoint = get_cudagraph_segments(pool_id)
4108
4109    #     state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)
4110
    #     graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
    #     with self.assertRaisesRegex(Exception, "being manually freed must be passed"):
    #         self.setCheckpointPoolState(outputs[0].device.index, state, [], [])
4114
4115    def test_tensor_dies_after_checkpoint(self):
4116        def foo():
4117            return int8_cuda(MIN_BLOCK_SIZE), int8_cuda(MIN_BLOCK_SIZE)
4118
4119        graph, outputs = cudagraphify(foo, [])
4120        pool_id = graph.pool()
4121        device = outputs[0].device.index
4122
4123        segments_before_checkpoint = get_cudagraph_segments(pool_id)
4124        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)
4125
4126        output_data_ptrs = [output.data_ptr() for output in outputs]
4127
4128        del outputs
4129
4130        self.setCheckpointPoolState(device, state, [], [])
4131
4132        self.assertEqual(live_blocks(pool_id), 2)
4133        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0])
4134        self.assertEqual(live_blocks(pool_id), 1)
4135        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[1])
4136        self.assertEqual(live_blocks(pool_id), 0)
4137
4138    def test_assigning_back_deleter_fns_to_tensor(self):
4139        def foo(x):
4140            return (
4141                int8_cuda(SMALL_BUFFER) + x,
4142                int8_cuda(SMALL_BUFFER) + x,
4143                int8_cuda(LARGE_BUFFER) + x,
4144            )
4145
4146        inp = torch.tensor([1], device="cuda")
4147        graph, outputs = cudagraphify(foo, [inp])
4148        pool_id = graph.pool()
4149        graph.replay()
4150
4151        device = outputs[0].device.index
4152
4153        for i in range(len(outputs)):
4154            self.assertTrue(outputs[i].mean(dtype=torch.float) == 2)
4155
4156        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)
4157
4158        output_ptrs = [output.untyped_storage().data_ptr() for output in outputs]
4159        ten_metadata = [tensor_metadata(t) for t in outputs]
4160
4161        self.assertEqual(live_blocks(pool_id), 3)
4162
4163        del outputs
4164
4165        self.assertEqual(live_blocks(pool_id), 0)
4166
4167        reconstructed_tensors = [
4168            reconstruct_from_tensor_metadata(metadata) for metadata in ten_metadata
4169        ]
4170
4171        for i in range(len(reconstructed_tensors)):
4172            self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2)
4173
4174        inp.add_(1)
4175        graph.replay()
4176
4177        for i in range(len(reconstructed_tensors)):
4178            self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3)
4179
4180        self.setCheckpointPoolState(
4181            device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]]
4182        )
4183
4184        self.assertEqual(live_blocks(pool_id), 3)
4185
4186        reconstructed_tensors[0] = None
4187        self.assertEqual(live_blocks(pool_id), 2)
4188
4189        reconstructed_tensors[1] = None
4190        self.assertEqual(live_blocks(pool_id), 1)
4191
        # should not change; this tensor was not passed in to have its deleter swapped
4193        reconstructed_tensors[2] = None
4194        self.assertEqual(live_blocks(pool_id), 1)
4195
4196        torch._C._cuda_cudaCachingAllocator_raw_delete(output_ptrs[2])
4197
4198        self.assertEqual(live_blocks(pool_id), 0)
4199
4200    @skipIfNoTorchVision
4201    def test_resnet(self):
4202        import torchvision
4203
4204        m = torchvision.models.resnet50()
4205        m.eval()
4206        m = m.cuda()
4207
4208        inp = torch.rand([1, 3, 255, 255], device="cuda")
4209        self.checkFunction(m, [inp])
4210
4211    def test_check_pool_live_allocations(self):
4212        def foo():
4213            return torch.ones([4], device="cuda")
4214
4215        pool = torch.cuda.graph_pool_handle()
4216        graph, outputs = cudagraphify(foo, [], pool=pool)
4217
4218        index = outputs[0].device.index
4219
4220        def check(live_dps):
4221            return torch._C._cuda_checkPoolLiveAllocations(index, pool, live_dps)
4222
4223        self.assertTrue(check({outputs[0].data_ptr()}))
4224
4225        self.assertFalse(check({outputs[0].data_ptr(), 0}))
4226        self.assertFalse(check(set()))
4227
4228        del outputs
4229        self.assertTrue(check(set()))
4230
4231    def test_allocate_in_thread_to_pool(self):
4232        def foo():
4233            return torch.rand([4], device="cuda")
4234
4235        pool = torch.cuda.graph_pool_handle()
4236        graph, outputs = cudagraphify(foo, [], pool=pool)
4237        device = outputs[0].device.index
4238        del outputs
4239
4240        @contextlib.contextmanager
4241        def _use_cuda_memory_pool_manager(device, mem_pool):
            """
            Context manager that routes new allocations on the current stream to the given
            CUDA graph memory pool. While it is active, any cudagraph tensors still in use
            must be reflected in the allocator or they may be overwritten. The mem_pool must
            already exist, e.g. from a previous graph capture.
            """
4247            torch.cuda.synchronize()
4248            stream = torch.cuda.Stream()
4249            stream.wait_stream(torch.cuda.current_stream())
4250            stream_context = torch.cuda.stream(stream)
4251            stream_context.__enter__()
4252            torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
4253            try:
4254                yield
4255            finally:
4256                torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
4257                torch._C._cuda_releasePool(device, mem_pool)
4258                stream_context.__exit__(None, None, None)
4259
4260        segments = get_cudagraph_segments(pool)
4261        self.assertEqual(len(get_cudagraph_segments(pool)), 1)
4262
4263        def use_pool():
4264            def alloc_three():
4265                a = int8_cuda(LARGE_BUFFER)
4266                b = int8_cuda(LARGE_BUFFER)
4267                c = a + b
4268
4269            with _use_cuda_memory_pool_manager(device, pool):
4270                # three allocations
4271                for _ in range(10):
4272                    alloc_three()
4273
4274            # three more allocations not in pool
4275            alloc_three()
4276
4277        def no_pool():
4278            # two allocations
4279            for _ in range(10):
4280                a = int8_cuda(LARGE_BUFFER)
4281                b = int8_cuda(LARGE_BUFFER)
4282                del a, b
4283
4284        graph_thread = threading.Thread(target=use_pool)
4285        no_graph_thread = threading.Thread(target=no_pool)
4286        graph_thread.start()
4287        no_graph_thread.start()
4288
4289        graph_thread.join()
4290        no_graph_thread.join()
4291
4292        self.assertEqual(
4293            len(get_cudagraph_segments(pool)), 2 if self.expandable_segments else 4
4294        )
4295
4296        del graph
4297
4298        torch.cuda.synchronize()
4299        gc.collect()
4300        torch.cuda.empty_cache()
4301
4302        self.assertEqual(len(get_cudagraph_segments(pool)), 0)
4303
4304    def test_no_triton_on_import(self):
        """Test that Triton is not imported on first GPU use"""
4306        script = "import sys; import torch; torch.rand(2, device='cuda'); print('triton' in sys.modules)"
4307
4308        rc = (
4309            subprocess.check_output(
4310                [sys.executable, "-c", script],
4311                # On Windows, opening the subprocess with the default CWD makes `import torch`
4312                # fail, so just set CWD to this script's directory
4313                cwd=os.path.dirname(os.path.realpath(__file__)),
4314            )
4315            .strip()
4316            .decode("ascii")
4317        )
4318        self.assertEqual(rc, "False", "Triton was imported when importing torch!")
4319
4320
4321@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
4322class TestMemPool(TestCase):
4323    def test_mempool_id(self):
4324        pool1 = torch.cuda.graph_pool_handle()
4325        pool2 = torch.cuda.MemPool().id
4326
4327        # first value of id in a user created pool is always zero
4328        self.assertEqual(pool1[0] == 0, pool2[0] == 0)
4329
4330        # each call to torch.cuda.graph_pool_handle() or torch.cuda.MemPool()
4331        # increments the id
4332        self.assertTrue(abs(pool2[1] - pool1[1]) > 0)
4333
4334    def test_mempool_with_allocator(self):
4335        pool = torch.cuda.MemPool()
4336
4337        # MemPool doesn't have an allocator by default
4338        self.assertEqual(pool.allocator, None)
4339
4340        from torch.utils.cpp_extension import load_inline
4341
4342        dummy_allocator_source = """
4343        #include <torch/extension.h>
4344        #include <ATen/cuda/Exceptions.h>
4345        #include <cuda_runtime_api.h>
4346
4347        extern "C" {
4348          C10_EXPORT int called_dummy_alloc = 0;
4349          C10_EXPORT int called_dummy_free = 0;
4350
4351          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
4352          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
4353            called_dummy_alloc = 123;
4354            void* ptr;
4355            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
4356            return ptr;
4357          }
4358
4359          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
4360            called_dummy_free = 321;
4361            C10_CUDA_CHECK(cudaFree(ptr));
4362          }
4363        }
4364        """
4365        dummy_allocator_libname = "dummy_allocator"
4366        dummy_allocator = load_inline(
4367            name=dummy_allocator_libname,
4368            cpp_sources=dummy_allocator_source,
4369            is_python_module=False,
4370            keep_intermediates=False,
4371            verbose=True,
4372            with_cuda=True,
4373        )
4374        allocator = torch.cuda.memory.CUDAPluggableAllocator(
4375            dummy_allocator,
4376            "dummy_alloc",
4377            "dummy_free",
4378        )
4379        pool = torch.cuda.MemPool(allocator.allocator())
4380
4381        # pool should point to the same allocator as the one passed into it
4382        self.assertEqual(allocator.allocator(), pool.allocator)
4383
4384        # no allocations happened yet, so called_dummy_alloc should be 0
4385        alloc_lib = ctypes.CDLL(dummy_allocator)
4386        called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
4387        self.assertEqual(called_dummy_alloc.value, 0)
4388
4389        with torch.cuda.use_mem_pool(pool):
4390            out = torch.randn(1, device="cuda")
4391
4392        # called_dummy_alloc should be 123 if dummy_alloc was used to allocate
4393        # out tensor
4394        self.assertEqual(called_dummy_alloc.value, 123)
4395
4396    def test_mempool_context(self):
4397        active_pool = torch.cuda.MemPoolContext.active_pool()
4398
4399        # there is no active pool if none was made active
4400        self.assertEqual(active_pool, None)
4401
4402        pool = torch.cuda.MemPool()
4403        ctx = torch.cuda.MemPoolContext(pool)
4404        active_pool = torch.cuda.MemPoolContext.active_pool()
4405
4406        # pool was made active
4407        self.assertEqual(active_pool, pool)
4408
4409        del ctx
4410        active_pool = torch.cuda.MemPoolContext.active_pool()
4411
        # ctx was deleted, so the previously active pool (None in this test) is restored
4413        self.assertEqual(active_pool, None)
4414
4415    def test_mempool_multithread(self):
4416        pool_ids = []
4417        active_pool_ids = []
4418
4419        def create_mempool_and_make_active():
4420            pool = torch.cuda.MemPool()
4421            pool_ids.extend([pool.id])
4422
4423            ctx = torch.cuda.MemPoolContext(pool)
4424            active_pool = torch.cuda.MemPoolContext.active_pool()
4425            active_pool_ids.extend([active_pool.id])
4426            del ctx
4427
4428        num_threads = 4
4429        threads = [
4430            threading.Thread(target=create_mempool_and_make_active)
4431            for t in range(num_threads)
4432        ]
4433        for thread in threads:
4434            thread.start()
4435        for thread in threads:
4436            thread.join()
4437
4438        # each thread should create a unique mempool, since
4439        # mempool id creation is atomic
4440        self.assertEqual(len(set(pool_ids)), 4)
4441
4442        # each thread should have different active mempool, since
4443        # the pointer to the mempool is thread local
4444        self.assertEqual(len(set(active_pool_ids)), 4)
4445
4446
4447@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
4448@torch.testing._internal.common_utils.markDynamoStrictTest
4449class TestCudaOptims(TestCase):
    # These tests will be instantiated with instantiate_device_type_tests
    # to apply the new OptimizerInfo structure.
4452
4453    @onlyCUDA
4454    @unittest.skipIf(
4455        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs"
4456    )
4457    @optims(
4458        [optim for optim in optim_db if optim.has_capturable_arg],
4459        dtypes=[torch.float32],
4460    )
4461    def test_graph_optims(self, device, dtype, optim_info):
4462        optim_cls = optim_info.optim_cls
4463        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
4464            device, dtype, optim_info, skip=("differentiable",)
4465        )
4466
4467        steps_warmup = 3
4468        steps_train = 2
4469
4470        for optim_input in all_optim_inputs:
4471            kwargs = optim_input.kwargs
4472
4473            # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam
4474            # and torch.optim.adamw
4475            kwargs["lr"] = 0.1
4476
4477            for actually_do_graphs in (True, False):
4478                params = [
4479                    torch.randn((i + 5, i + 5), device=device) for i in range(2)
4480                ] + [torch.randn((), device=device)]
4481                params_control = [p.clone().requires_grad_() for p in params]
4482                params_graphed = [p.clone().requires_grad_() for p in params]
4483
4484                grads = [
4485                    [torch.randn_like(p) for p in params]
4486                    for _ in range(steps_warmup + steps_train)
4487                ]
4488
4489                # Control (capturable=False)
4490                kwargs["capturable"] = False
4491
4492                opt = optim_cls(params_control, **kwargs)
4493                for i in range(steps_warmup + steps_train):
4494                    for j, p in enumerate(params_control):
4495                        p.grad = grads[i][j]
4496                    opt.step()
4497
4498                # capturable=True
4499                kwargs["capturable"] = True
4500                opt = optim_cls(params_graphed, **kwargs)
4501
4502                for i in range(steps_warmup):
4503                    for j, p in enumerate(params_graphed):
4504                        p.grad = grads[i][j]
4505                    opt.step()
4506
4507                if actually_do_graphs:
4508                    g = torch.cuda.CUDAGraph()
4509                    with torch.cuda.graph(g):
4510                        opt.step()
4511
4512                for i in range(steps_train):
4513                    if actually_do_graphs:
4514                        for j, p in enumerate(params_graphed):
4515                            p.grad.copy_(grads[i + steps_warmup][j])
4516                        g.replay()
4517                    else:
4518                        # Passing capturable=True to the constructor and running without graphs should still be
4519                        # numerically correct, even if it's not ideal for performance.
4520                        for j, p in enumerate(params_graphed):
4521                            p.grad = grads[i + steps_warmup][j]
4522                        opt.step()
4523
4524                for p_control, p_graphed in zip(params_control, params_graphed):
4525                    self.assertEqual(p_control, p_graphed)
4526
4527    @onlyCUDA
4528    @unittest.skipIf(
4529        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
4530    )
4531    @optims(
4532        [
4533            optim
4534            for optim in optim_db
4535            if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on
4536        ],
4537        dtypes=[torch.float32],
4538    )
4539    def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
4540        optim_cls = optim_info.optim_cls
4541
4542        steps_warmup = 3
4543        steps_train = 2
4544
4545        optim_inputs = optim_info.optim_inputs_func(device=device)
4546
4547        for optim_input in optim_inputs:
4548            kwargs = optim_input.kwargs
4549            kwargs["fused"] = True
4550
4551            for actually_do_graphs in (
4552                (True, False) if optim_info.has_capturable_arg else (True,)
4553            ):
4554                params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
4555                params_control = [p.clone().requires_grad_() for p in params]
4556                params_graphed = [p.clone().requires_grad_() for p in params]
4557
                # `GradScaler` updates gradients in place, so each run needs its own copy of the gradients.
4559                grads = [
4560                    [torch.randn_like(p) for p in params]
4561                    for _ in range(steps_warmup + steps_train)
4562                ]
4563                with torch.no_grad():
4564                    grads_control = [[g.clone() for g in gs] for gs in grads]
4565                    grads_graphed = [[g.clone() for g in gs] for gs in grads]
4566
4567                # Gradient Scaler
4568                scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
4569                with torch.no_grad():
4570                    scaler_for_control._lazy_init_scale_growth_tracker(device)
4571
4572                scaler_for_graphed = torch.cuda.amp.GradScaler()
4573                scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
4574                with torch.no_grad():
4575                    scaler_for_graphed._lazy_init_scale_growth_tracker(device)
4576
4577                # Control (capturable=False)
4578                if optim_info.has_capturable_arg:
4579                    kwargs["capturable"] = False
4580                opt = optim_cls(params_control, **kwargs)
4581
4582                for i in range(steps_warmup + steps_train):
4583                    for j, p in enumerate(params_control):
4584                        p.grad = grads_control[i][j]
4585                    scaler_for_control.step(opt)
4586                    scaler_for_control.update()
4587
4588                # capturable=True
4589                if optim_info.has_capturable_arg:
4590                    kwargs["capturable"] = True
4591                opt = optim_cls(params_graphed, **kwargs)
4592
4593                for i in range(steps_warmup):
4594                    for j, p in enumerate(params_graphed):
4595                        p.grad = grads_graphed[i][j]
4596                    scaler_for_graphed.step(opt)
4597                    scaler_for_graphed.update()
4598
4599                if actually_do_graphs:
4600                    g = torch.cuda.CUDAGraph()
4601                    with torch.cuda.graph(g):
4602                        scaler_for_graphed.step(opt)
4603                        scaler_for_graphed.update()
4604
4605                for i in range(steps_train):
4606                    if actually_do_graphs:
4607                        for j, p in enumerate(params_graphed):
4608                            p.grad.copy_(grads_graphed[i + steps_warmup][j])
4609                        g.replay()
4610                    else:
4611                        # Passing capturable=True to the constructor and running without graphs should still be
4612                        # numerically correct, even if it's not ideal for performance.
4613                        for j, p in enumerate(params_graphed):
4614                            p.grad = grads_graphed[i + steps_warmup][j]
4615                        scaler_for_graphed.step(opt)
4616                        scaler_for_graphed.update()
4617
4618                for p_control, p_graphed in zip(params_control, params_graphed):
4619                    self.assertEqual(p_control, p_graphed)
4620
4621    @onlyNativeDeviceTypes
4622    @optims(
4623        [optim for optim in optim_db if "fused" in optim.supported_impls],
4624        dtypes=[torch.float32],
4625    )
4626    def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
4627        device = device.split(":")[0]
4628        if device not in optim_info.supports_fused_on:
4629            self.skipTest(
4630                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
4631            )
4632        optim_inputs = optim_info.optim_inputs_func(device=device)
4633        optim_cls = optim_info.optim_cls
4634        for optim_input in optim_inputs:
4635            for _separate_unscale in (True, False):
4636                kwargs = optim_input.kwargs
4637                kwargs["fused"] = True
4638                torch.manual_seed(20)
4639                (
4640                    mod_control,
4641                    mod_scaling,
4642                    opt_control,
4643                    opt_scaling,
4644                    data,
4645                    loss_fn,
4646                    _,
4647                ) = _create_scaling_case(
4648                    optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device
4649                )
4650                optimizer_kwargs = deepcopy(kwargs)
4651                optimizer_kwargs["fused"] = False
4652                if "lr" not in kwargs:
4653                    # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr
4654                    optimizer_kwargs["lr"] = 1.0
4655                opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs)
4656                scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0)
4657                scaler_control = torch.amp.GradScaler(device, init_scale=128.0)
4658                tracker = TensorTracker()
4659                for input, target in data:
4660                    opt_control.zero_grad()
4661                    with torch.autocast(device_type=device, dtype=torch.half):
4662                        output_control = mod_control(input)
4663                        loss_control = loss_fn(output_control, target)
4664                    scaler_control.scale(loss_control).backward()
4665                    scaler_control.step(opt_control)
4666                    scaler_control.update()
4667
4668                    opt_scaling.zero_grad()
4669                    with torch.autocast(device_type=device, dtype=torch.half):
4670                        output_scaling = mod_scaling(input)
4671                        loss_scaling = loss_fn(output_scaling, target)
4672                    scaler_scaling.scale(loss_scaling).backward()
4673                    if _separate_unscale:
4674                        scaler_scaling.unscale_(opt_scaling)
4675                    scaler_scaling.step(opt_scaling)
4676                    scaler_scaling.update()
4677
4678                    tracker.add(loss_control)
4679                    tracker.pop_check_set(loss_scaling, self)
4680                    for param_control, param_scaling in zip(
4681                        mod_control.parameters(), mod_scaling.parameters()
4682                    ):
4683                        tracker.add(param_control.grad)
4684                        tracker.pop_check_set(param_scaling.grad, self)
4685                        tracker.add(param_control)
4686                        tracker.pop_check_set(param_scaling, self)
4687
4688                        state_control, state_scaling = (
4689                            opt_control.state[param_control],
4690                            opt_scaling.state[param_scaling],
4691                        )
4692
4693                        for k in state_control:
4694                            actual = state_scaling[k]
4695                            if k == "step":
4696                                actual = actual.squeeze()
4697                            tracker.add(state_control[k])
4698                            tracker.pop_check_set(actual, self)
4699
4700    @onlyCUDA
4701    @parametrize("in_place_unscale", [False, True])
4702    @optims(
4703        [optim for optim in optim_db if "cuda" in optim.supports_fused_on],
4704        dtypes=[torch.float32],
4705    )
4706    def test_grad_scaler_with_preset_grad_scale(
4707        self, device, dtype, optim_info, in_place_unscale
4708    ):
4709        weight = torch.ones((5, 5), device="cuda", requires_grad=True)
4710        weight.grad = torch.full_like(weight, fill_value=15)
4711        opt = optim_info.optim_cls([weight], lr=0.1, fused=True)
4712        scaler = torch.amp.GradScaler(init_scale=5)
4713
4714        # simulate scaling a loss
4715        scaler.scale(torch.ones(5))
4716
4717        if in_place_unscale:
4718            scaler.unscale_(opt)
4719            # the gradient should have been divided in-place
4720            self.assertEqual(weight.grad, torch.full_like(weight, fill_value=3))
4721
4722        # the user sets a `grad_scale` value which should be fused with the optimizer step
4723        opt.grad_scale = torch.Tensor([3]).cuda()
4724        scaler.step(opt)
4725
4726        # check that the user's grad_scale was respected (i.e. the gradient was divided by 5 * 3)
4727        self.assertEqual(weight.grad, torch.full_like(weight, fill_value=1))
4728
4729    @onlyCUDA
4730    @unittest.skipIf(
4731        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
4732    )
4733    @parametrize("foreach, fused", [(False, False), (True, False), (False, True)])
4734    @optims(
4735        [
4736            optim
4737            for optim in optim_db
4738            if "foreach" in optim.supported_impls and "cuda" in optim.supports_fused_on
4739        ],
4740        dtypes=[torch.float32],
4741    )
4742    def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused):
4743        torch.cuda.empty_cache()
4744
4745        scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0)
4746        g = torch.cuda.CUDAGraph()
4748
4749        weight = torch.ones((100,), device="cuda", requires_grad=True)
4750        opt = optim_info.optim_cls([weight], lr=0.1, foreach=foreach, fused=fused)
4751        static_input = torch.ones_like(weight)
4752        static_grad = torch.ones_like(weight)
4753
4754        # warmup
4755        s = torch.cuda.Stream()
4756        s.wait_stream(torch.cuda.current_stream())
4757        with torch.cuda.stream(s):
4758            loss = (weight.half() * static_input).sum()
4759            scaler.scale(loss).backward()
4760        torch.cuda.current_stream().wait_stream(s)
4761
4762        opt.zero_grad(set_to_none=True)
4763
4764        # capture
4765        with torch.cuda.stream(s):
4766            g.capture_begin()
4767            loss = (weight.half() * static_input).sum()
4768            scaler.scale(loss).backward()
4769            g.capture_end()
4770
4771        input_vals = [5, 20000, 5, 40000]
4772        # If the scale gets updated properly, these are the scale, growth tracker,
4773        # and grad values we expect.
4774        expected_scales = [4, 2, 2, 1]
4775        expected_growth_trackers = [1, 0, 1, 0]
4776        expected_grad_vals = [5 * 4, float("inf"), 5 * 2, float("inf")]
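        # Worked example for step 2: static_input=20000 gives a per-element grad of
        # 20000 * scale(4) = 80000, which overflows fp16 to inf, so update() halves the
        # scale to 2 and resets the growth tracker; finite steps leave the scale alone
        # (the default growth_interval is much larger than this loop) and bump the tracker.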
4777
4778        for data, scale, growth_tracker, grad_val in zip(
4779            input_vals, expected_scales, expected_growth_trackers, expected_grad_vals
4780        ):
4781            static_input.fill_(data)
4782            g.replay()
4783            self.assertEqual(weight.grad, torch.full_like(weight.grad, grad_val))
4784            scaler.step(opt)
4785            scaler.update()
4786            self.assertEqual(scaler._scale, scale)
4787            self.assertEqual(scaler._growth_tracker, growth_tracker)
4788
4789
4790@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
4791class TestGDS(TestCase):
4792    def _get_tmp_dir_fs_type(self):
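        """Return the filesystem type of the mount backing /tmp, falling back to the
        root filesystem's type when /tmp is not a separate mount."""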
4793        my_path = os.path.realpath("/tmp")
4794        root_type = ""
4795        for part in psutil.disk_partitions():
4796            if part.mountpoint == "/":
4797                root_type = part.fstype
4798                continue
4799            if part.mountpoint == my_path:
4800                return part.fstype
4801        return root_type
4802
4803    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
4804    def test_gds_read_write_tensors(self):
4805        if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
4806            self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
4807        src1 = torch.randn(1024, device="cuda")
4808        src2 = torch.randn(2, 1024, device="cuda")
4809        torch.cuda.gds._gds_register_buffer(src1.untyped_storage())
4810        torch.cuda.gds._gds_register_buffer(src2.untyped_storage())
4811        dest1 = torch.empty(1024, device="cuda")
4812        dest2 = torch.empty(2, 1024, device="cuda")
4813        with TemporaryFileName() as f:
4814            file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
4815            file.save_storage(src1.untyped_storage(), offset=0)
4816            file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
4817            file.load_storage(dest1.untyped_storage(), offset=0)
4818            file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
4819        self.assertEqual(src1, dest1)
4820        self.assertEqual(src2, dest2)
4821        torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage())
4822        torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())
4823
4824
4825@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
4826class TestCudaAutocast(TestAutocast):
4827    def setUp(self):
4828        super().setUp()
4829        self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))
4830
4831    def tearDown(self):
4832        del self.autocast_lists
4833        super().tearDown()
4834
4835    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4836    def test_autocast_torch_fp16(self):
4837        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4838            for op_with_args in self.autocast_lists.torch_fp16:
4839                skip_test = False
4840                op, args = op_with_args[0], op_with_args[1]
4841                if len(op_with_args) == 3:
4842                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
4843                if not skip_test:
4844                    self._run_autocast_outofplace(
4845                        op, args, torch.float16, device="cuda", amp_dtype=torch.float16
4846                    )
4847
4848    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4849    def test_autocast_torch_bf16(self):
4850        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4851            for op_with_args in self.autocast_lists.torch_fp16:
4852                skip_test = False
4853                op, args = op_with_args[0], op_with_args[1]
4854                if len(op_with_args) == 3:
4855                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
4856                should_error_from_cudnn = "cudnn" in op and (
4857                    "TORCH_CUDNN_V8_API_DISABLED" in os.environ
4858                    and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
4859                    or torch.cuda.get_device_capability() < (8, 0)
4860                )
4861                should_error_from_not_implemented = should_error_from_cudnn
4862                if not skip_test:
4863                    if should_error_from_not_implemented:
4864                        with self.assertRaises(
4865                            RuntimeError,
4866                            msg=str(op) + " should not be supported for bfloat16!",
4867                        ):
4868                            self._run_autocast_outofplace(
4869                                op, args, torch.bfloat16, device="cuda"
4870                            )
4871                    else:
4872                        if torch.cuda.is_bf16_supported():
4873                            self._run_autocast_outofplace(
4874                                op, args, torch.bfloat16, device="cuda"
4875                            )
4876                        else:
4877                            with self.assertRaisesRegex(
4878                                RuntimeError, "Device does not support bfloat16"
4879                            ):
4880                                self._run_autocast_outofplace(
4881                                    op, args, torch.bfloat16, device="cuda"
4882                                )
4883
4884    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4885    def test_autocast_torch_fp32(self):
4886        for op_with_args in self.autocast_lists.torch_fp32:
4887            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
4888            self._run_autocast_outofplace(
4889                op,
4890                args,
4891                torch.float32,
4892                device="cuda",
4893                add_kwargs=maybe_kwargs,
4894                amp_dtype=torch.float16,
4895            )
4896
4897    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4898    def test_autocast_torch_need_autocast_promote(self):
4899        for op, args in self.autocast_lists.torch_need_autocast_promote:
4900            self._run_autocast_outofplace(
4901                op, args, torch.float32, device="cuda", amp_dtype=torch.float16
4902            )
4903
4904    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4905    def test_autocast_torch_expect_builtin_promote(self):
4906        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
4907            self._run_autocast_outofplace(
4908                op,
4909                args,
4910                torch.float32,
4911                device="cuda",
4912                out_type=out_type,
4913                amp_dtype=torch.float16,
4914            )
4915
4916    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4917    def test_autocast_nn_fp16(self):
4918        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4919            for op, args in self.autocast_lists.nn_fp16:
4920                self._run_autocast_outofplace(
4921                    op,
4922                    args,
4923                    torch.float16,
4924                    device="cuda",
4925                    module=torch._C._nn,
4926                    amp_dtype=torch.float16,
4927                )
4928
4929    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4930    def test_autocast_nn_bf16(self):
4931        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4932            for op, args in self.autocast_lists.nn_fp16:
4933                if torch.cuda.is_bf16_supported():
4934                    self._run_autocast_outofplace(
4935                        op, args, torch.bfloat16, device="cuda", module=torch._C._nn
4936                    )
4937                else:
4938                    with self.assertRaisesRegex(
4939                        RuntimeError, "Device does not support bfloat16"
4940                    ):
4941                        self._run_autocast_outofplace(
4942                            op, args, torch.bfloat16, device="cuda", module=torch._C._nn
4943                        )
4944
4945    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4946    def test_autocast_nn_fp32(self):
4947        for op, args in self.autocast_lists.nn_fp32:
4948            self._run_autocast_outofplace(
4949                op,
4950                args,
4951                torch.float32,
4952                device="cuda",
4953                module=torch._C._nn,
4954                amp_dtype=torch.float16,
4955            )
4956
4957    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4958    def test_autocast_linalg_fp16(self):
4959        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4960            for op, args in self.autocast_lists.linalg_fp16:
4961                self._run_autocast_outofplace(
4962                    op,
4963                    args,
4964                    torch.float16,
4965                    device="cuda",
4966                    module=torch._C._linalg,
4967                    amp_dtype=torch.float16,
4968                )
4969
4970    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4971    def test_autocast_methods_fp16(self):
4972        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
4973            for op, args in self.autocast_lists.methods_fp16:
4974                self._run_autocast_outofplace(
4975                    op,
4976                    args,
4977                    torch.float16,
4978                    device="cuda",
4979                    module=None,
4980                    amp_dtype=torch.float16,
4981                )
4982
4983    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4984    def test_autocast_methods_fp32(self):
4985        for op, args in self.autocast_lists.methods_fp32:
4986            self._run_autocast_outofplace(
4987                op,
4988                args,
4989                torch.float32,
4990                device="cuda",
4991                module=None,
4992                amp_dtype=torch.float16,
4993            )
4994
4995    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
4996    def test_autocast_methods_expect_builtin_promote(self):
4997        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
4998            self._run_autocast_outofplace(
4999                op,
5000                args,
5001                torch.float32,
5002                device="cuda",
5003                module=None,
5004                out_type=out_type,
5005                amp_dtype=torch.float16,
5006            )
5007
5008    def test_autocast_banned(self):
5009        with torch.autocast("cuda"):
5010            for op, args, module in self.autocast_lists.banned:
5011                with self.assertRaises(RuntimeError):
5012                    getattr(module, op)(*args)
5013
5014    def test_autocast_ignored_types(self):
5015        with torch.autocast("cuda"):
5016            for ignore_type in (torch.double, torch.int32):
5017                a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
5018                b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
5019                c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")
5020
5021                # Tests if CastPolicy::fp16 ops ignore double and int
5022                # Currently, no ops belonging to this policy support integer inputs.
5023                if ignore_type is torch.double:
5024                    with self.assertRaises(RuntimeError):
5025                        torch.mm(a_ignore, c_16)
5026                    with torch.autocast("cuda", enabled=False):
5027                        type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
5028                    self.assertTrue(
5029                        torch.mm(a_ignore, b_ignore).dtype is type_no_autocast
5030                    )
5031
5032                # Tests if CastPolicy::fp32 ops ignore double and int
5033                with torch.autocast("cuda", enabled=False):
5034                    type_no_autocast = torch.pow(a_ignore, 2.0).dtype
5035                self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)
5036
5037                # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
5038                with torch.autocast("cuda", enabled=False):
5039                    type_no_autocast = torch.sum(a_ignore).dtype
5040                self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)
5041
5042                # Tests if CastPolicy::fp32_append_dtype ops ignore double and int
5043                # Currently, no ops belonging to this policy support integer inputs.
5044                if ignore_type is torch.double:
5045                    with torch.autocast("cuda", enabled=False):
5046                        type_no_autocast = torch.norm(a_ignore).dtype
5047                    self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)
5048
5049    def test_autocast_custom_enabled(self):
5050        class MyMM(torch.autograd.Function):
5051            @staticmethod
5052            @torch.amp.custom_fwd(device_type="cuda")
5053            def forward(ctx, a, b):
5054                self.assertTrue(a.dtype is torch.float32)
5055                self.assertTrue(b.dtype is torch.float32)
5056                self.assertTrue(torch.is_autocast_enabled())
5057                ctx.save_for_backward(a, b)
5058                return a.mm(b)
5059
5060            @staticmethod
5061            @torch.amp.custom_bwd(device_type="cuda")
5062            def backward(ctx, grad):
5063                self.assertTrue(torch.is_autocast_enabled())
5064                a, b = ctx.saved_tensors
5065                a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
5066                self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype)
5067                return a_grad, b_grad
5068
5069        mymm = MyMM.apply
5070
5071        x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
5072        y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
5073
5074        dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
5075        for dtype in dtypes:
5076            with torch.cuda.amp.autocast(dtype=dtype):
5077                output = mymm(x, y)
5078                self.assertTrue(output.dtype is dtype)
5079                loss = output.sum()
5080            loss.backward()
5081
5082    def test_autocast_custom_cast_inputs(self):
5083        class MyMM(torch.autograd.Function):
5084            @staticmethod
5085            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
5086            def forward(ctx, a, container, expect_type):
5087                b = container[1][0]
5088                self.assertTrue(a.dtype is expect_type)
5089                self.assertTrue(b.dtype is expect_type)
5090                self.assertFalse(torch.is_autocast_enabled())
5091                ctx.save_for_backward(a, b)
5092                return a.mm(b)
5093
5094            @staticmethod
5095            @torch.amp.custom_bwd(device_type="cuda")
5096            def backward(ctx, grad):
5097                self.assertFalse(torch.is_autocast_enabled())
5098                a, b = ctx.saved_tensors
5099                return grad.mm(b.t()), None, None
5100
5101        mymm = MyMM.apply
5102
5103        x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
5104        # Puts one input tensor in a nested container.  y's contained Tensor won't receive a gradient,
5105        # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments.
5106        # Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
5107        y = (
5108            0,
5109            {
5110                0: torch.randn(
5111                    (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
5112                )
5113            },
5114        )
5115
5116        with torch.autocast("cuda"):
5117            output = mymm(x, y, torch.float32)
5118            self.assertTrue(output.dtype is torch.float32)
5119            loss = output.sum()
5120        loss.backward()
5121
5122        # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
5123        output = mymm(x, y, torch.float16)
5124        self.assertTrue(output.dtype is torch.float16)
5125        loss = output.sum()
5126        loss.backward()
5127
5128    def test_autocast_custom_deprecated_warning(self):
5129        with warnings.catch_warnings(record=True) as w:
5130
5131            class MyMM(torch.autograd.Function):
5132                @staticmethod
5133                @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
5134                def forward(ctx, x, y):
5135                    ctx.save_for_backward(x, y)
5136                    self.assertFalse(torch.is_autocast_enabled())
5137                    return x + y
5138
5139                @staticmethod
5140                @torch.cuda.amp.custom_bwd
5141                def backward(ctx, grad):
5142                    _, _ = ctx.saved_tensors
5143                    self.assertFalse(torch.is_autocast_enabled())
5144                    return grad, grad
5145
5146        self.assertRegex(
5147            str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
5148        )
5149        self.assertRegex(
5150            str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
5151        )
5152
5153        mymm = MyMM.apply
5154        x = torch.randn(3, 3, requires_grad=True)
5155        y = torch.randn(3, 3, requires_grad=True)
5156        with torch.amp.autocast("cuda"):
5157            output = mymm(x, y)
5158            loss = output.sum()
5159        loss.backward()
5160
5161    def test_autocast_cat_jit(self):
5162        # Reported at https://github.com/pytorch/pytorch/issues/38958
5163
5164        class Model(torch.nn.Module):
5165            def forward(self):
5166                a = torch.randn(1)
5167                b = torch.randn(1)
5168                c = torch.cat((a, b), 0)
5169                d = torch.stack([c, c], 0)
5170                return d
5171
5172        # The JIT here doesn't really matter, we just need to call
5173        # cat via the boxed API
5174        model = Model()
5175        model_jit_script = torch.jit.script(model)
5176
5177        with torch.autocast("cuda", enabled=True):
5178            model()
5179            model_jit_script()
5180
5181    # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened)
5182    # so they get a dedicated test.
5183    # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100).
5184    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
5185    def test_autocast_rnn(self):
5186        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
5187            # seq, batch, features, hidden size
5188            clses = ("RNN", "GRU", "LSTM")
5189            T, B, F, H = 3, 4, 5, 6
5190            dtypes = (torch.float16, torch.float32)
5191            input_layouts = ("seq_first", "batch_first", "packed")
5192
5193            for (
5194                cls,
5195                num_layers,
5196                bias,
5197                input_layout,
5198                bidirectional,
5199                try_nonpreflattened_weights,
5200                input_dtype,
5201                hidden_dtype,
5202                weight_dtype,
5203            ) in product(
5204                clses,
5205                (1, 2),
5206                (True, False),
5207                input_layouts,
5208                (True, False),
5209                (True, False),
5210                dtypes,
5211                dtypes,
5212                dtypes,
5213            ):
5214                if input_layout == "seq_first":
5215                    batch_first = False
5216                    x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
5217                elif input_layout == "batch_first":
5218                    batch_first = True
5219                    x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
5220                elif input_layout == "packed":
5221                    batch_first = False
5222                    x = torch.nn.utils.rnn.pack_padded_sequence(
5223                        torch.randn((T, B, F), device="cuda", dtype=input_dtype),
5224                        lengths=(3, 2, 1, 3),
5225                        enforce_sorted=False,
5226                    )
5227
5228                rnn = (
5229                    getattr(torch.nn, cls)(
5230                        F,
5231                        H,
5232                        num_layers=num_layers,
5233                        bidirectional=bidirectional,
5234                        bias=bias,
5235                        batch_first=batch_first,
5236                    )
5237                    .cuda()
5238                    .to(dtype=weight_dtype)
5239                )
5240
5241                if try_nonpreflattened_weights:
5242                    for p in rnn.parameters():
5243                        with torch.no_grad():
5244                            p.set_(p.clone())
5245
5246                h = torch.randn(
5247                    (num_layers * (2 if bidirectional else 1), B, H),
5248                    device="cuda",
5249                    dtype=hidden_dtype,
5250                )
5251                if cls == "LSTM":
5252                    c = torch.randn(
5253                        (num_layers * (2 if bidirectional else 1), B, H),
5254                        device="cuda",
5255                        dtype=hidden_dtype,
5256                    )
5257                    h = (h, c)
5258
5259                with torch.autocast("cuda"):
5260                    out, h_out = rnn(x, h)
5261                out = out.data if input_layout == "packed" else out
5262                self.assertEqual(out.dtype, torch.float16)
                # The autocast wrapper requires that at::_cudnn_rnn is autograd-exposed. This check
                # can't guarantee that, but if it fails, it indicates some funny business has
                # occurred and we should double-check that at::_cudnn_rnn remains autograd-exposed.
5266                self.assertEqual(
5267                    out.grad_fn.name(),
5268                    "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
5269                )
5270                out.sum().backward()
5271                grads = [p.grad.clone() for p in rnn.parameters()]
5272
5273                rnn.zero_grad()
5274
5275                if cls == "LSTM":
5276                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
5277                        x.half(), (h[0].half(), h[1].half())
5278                    )
5279                else:
5280                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
5281                        x.half(), h.half()
5282                    )
5283                out_control = (
5284                    out_control.data if input_layout == "packed" else out_control
5285                )
5286                out_control.sum().backward()
5287                grads_control = [p.grad.clone() for p in rnn.parameters()]
5288
5289                # Compares with default tolerances, even for FP16 execution.  Barring nondeterminism,
5290                # autocast and control results should be bitwise identical.
5291                self.assertEqual(out, out_control)
5292
5293                if cls == "LSTM":
5294                    self.assertTrue(
5295                        h_out[0].dtype is torch.float16
5296                        and h_out[1].dtype is torch.float16
5297                    )
5298                    self.assertEqual(h_out[0], h_out_control[0])
5299                    self.assertEqual(h_out[1], h_out_control[1])
5300                else:
5301                    self.assertEqual(h_out.dtype, torch.float16)
5302                    self.assertEqual(h_out, h_out_control)
5303                for grad, grad_control in zip(grads, grads_control):
5304                    self.assertEqual(grad.half(), grad_control)
5305
5306    def test_autocast_cache_leak(self):
5307        # Reported at https://github.com/pytorch/pytorch/issues/48049
        # The test checks that autocast reuses its parameter cache (so memory use does not
        # grow) when executed inside a `torch.no_grad()` block.
5310
5311        linear = torch.nn.Linear(10, 10).to("cuda")
5312        data = torch.randn(1, 10, device="cuda")
5313
5314        with torch.autocast("cuda"):
5315            with torch.no_grad():
5316                out = linear(data)
5317                first_iter_mem = torch.cuda.memory_allocated()
5318                for _ in range(3):
5319                    out = linear(data)
5320                self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())
5321
5322    def test_autocast_checkpointing(self):
5323        model = torch.nn.Sequential(
5324            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
5325        ).cuda()
5326        input = torch.rand(
5327            (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
5328        )
5329        for reentrant in (True, False):
5330            with torch.autocast("cuda"):
5331                output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
5332            self.assertTrue(output.requires_grad)
5333            self.assertTrue(output.dtype is torch.float16)
5334            output.sum().backward()
5335
5336    def test_cuda_autocast_deprecated_warning(self):
5337        with self.assertWarnsRegex(
5338            FutureWarning,
5339            r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
5340        ):
5341            with torch.cuda.amp.autocast():
5342                _ = torch.ones(10)
5343
5344
5345instantiate_parametrized_tests(TestCuda)
5346instantiate_parametrized_tests(TestCudaMallocAsync)
5347instantiate_device_type_tests(TestCudaOptims, globals())
5348
5349if __name__ == "__main__":
5350    run_tests()
5351