# xref: /aosp_15_r20/external/pytorch/test/test_cuda_trace.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: cuda"]
2
3import sys
4import unittest
5import unittest.mock
6
7import torch
8import torch.cuda._gpu_trace as gpu_trace
9from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase
10
11
12# NOTE: Each test needs to be run in a brand new process, to reset the registered hooks
13# and make sure the CUDA streams are initialized for each test that uses them.
14
if not TEST_CUDA:
    # No CUDA device means nothing to trace: report once on stderr and
    # swap the base class for NoTest so every test in the file is skipped.
    sys.stderr.write("CUDA not available, skipping tests\n")
    TestCase = NoTest  # noqa: F811
18
19
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaTrace(TestCase):
    """Exercise the ``torch.cuda._gpu_trace`` callback hooks.

    Each test registers a MagicMock for one trace point, performs the CUDA
    operation that should fire it, then asserts the mock was invoked with
    the expected identifiers (event handle / stream handle / data pointer).
    """

    def setUp(self):
        # Turn on the GPU trace machinery and hand every test a fresh mock.
        torch._C._activate_gpu_trace()
        self.mock = unittest.mock.MagicMock()

    def test_event_creation_callback(self):
        gpu_trace.register_callback_for_event_creation(self.mock)

        evt = torch.cuda.Event()
        evt.record()
        self.mock.assert_called_once_with(evt._as_parameter_.value)

    def test_event_deletion_callback(self):
        gpu_trace.register_callback_for_event_deletion(self.mock)

        evt = torch.cuda.Event()
        evt.record()
        # Capture the handle before deletion; the callback receives it after.
        evt_id = evt._as_parameter_.value
        del evt
        self.mock.assert_called_once_with(evt_id)

    def test_event_record_callback(self):
        gpu_trace.register_callback_for_event_record(self.mock)

        evt = torch.cuda.Event()
        evt.record()
        self.mock.assert_called_once_with(
            evt._as_parameter_.value, torch.cuda.default_stream().cuda_stream
        )

    def test_event_wait_callback(self):
        gpu_trace.register_callback_for_event_wait(self.mock)

        evt = torch.cuda.Event()
        evt.record()
        evt.wait()
        self.mock.assert_called_once_with(
            evt._as_parameter_.value, torch.cuda.default_stream().cuda_stream
        )

    def test_memory_allocation_callback(self):
        gpu_trace.register_callback_for_memory_allocation(self.mock)

        buf = torch.empty(10, 4, device="cuda")
        self.mock.assert_called_once_with(buf.data_ptr())

    def test_memory_deallocation_callback(self):
        gpu_trace.register_callback_for_memory_deallocation(self.mock)

        buf = torch.empty(3, 8, device="cuda")
        # Remember the pointer before freeing; the callback receives it after.
        ptr = buf.data_ptr()
        del buf
        self.mock.assert_called_once_with(ptr)

    def test_stream_creation_callback(self):
        gpu_trace.register_callback_for_stream_creation(self.mock)

        # see Note [HIP Lazy Streams]: on ROCm a stream is created lazily,
        # so we must run work on it before the creation hook fires.
        if torch.version.hip:
            hip_stream = torch.cuda.Stream()
            with torch.cuda.stream(hip_stream):
                _force_stream_init = torch.ones(5, device="cuda")
        else:
            torch.cuda.Stream()

        self.mock.assert_called()

    def test_device_synchronization_callback(self):
        gpu_trace.register_callback_for_device_synchronization(self.mock)

        torch.cuda.synchronize()
        self.mock.assert_called()

    def test_stream_synchronization_callback(self):
        gpu_trace.register_callback_for_stream_synchronization(self.mock)

        side_stream = torch.cuda.Stream()
        side_stream.synchronize()
        self.mock.assert_called_once_with(side_stream.cuda_stream)

    def test_event_synchronization_callback(self):
        gpu_trace.register_callback_for_event_synchronization(self.mock)

        evt = torch.cuda.Event()
        evt.record()
        evt.synchronize()
        self.mock.assert_called_once_with(evt._as_parameter_.value)

    def test_memcpy_synchronization(self):
        gpu_trace.register_callback_for_stream_synchronization(self.mock)

        # nonzero() forces a device-to-host copy, which synchronizes the
        # default stream and should trigger the hook exactly once.
        src = torch.rand(5, device="cuda")
        src.nonzero()
        self.mock.assert_called_once_with(torch.cuda.default_stream().cuda_stream)

    def test_all_trace_callbacks_called(self):
        # Registering a second callback must not displace the first:
        # both hooks fire for the same allocation.
        second_hook = unittest.mock.MagicMock()
        gpu_trace.register_callback_for_memory_allocation(self.mock)
        gpu_trace.register_callback_for_memory_allocation(second_hook)

        buf = torch.empty(10, 4, device="cuda")
        self.mock.assert_called_once_with(buf.data_ptr())
        second_hook.assert_called_once_with(buf.data_ptr())
124
125
# Standard PyTorch test entry point; run_tests handles CLI flags and filtering.
if __name__ == "__main__":
    run_tests()
128