# mypy: allow-untyped-defs
import itertools
import warnings
from typing_extensions import deprecated

import torch
import torch.cuda
from torch.autograd import (
    _disable_profiler_legacy,
    _enable_profiler_legacy,
    DeviceType,
    ProfilerConfig,
    ProfilerState,
)
from torch.autograd.profiler_util import (
    _filter_name,
    _filter_stack_entry,
    _rewrite_name,
    EventList,
    FunctionEvent,
    MEMORY_EVENT_NAME,
)


__all__ = ["profile"]


@deprecated(
    "`torch.autograd.profiler_legacy.profile` is deprecated and will be removed in a future release. "
    "Please use `torch.profiler` instead.",
    category=None,  # TODO: change to `FutureWarning`
)
class profile:
    """DEPRECATED: use torch.profiler instead."""

    def __init__(
        self,
        enabled=True,
        *,
        use_cuda=False,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
    ):
        self.enabled: bool = enabled
        # Set unconditionally so __repr__/__str__ work even when disabled.
        self.function_events = None
        if not self.enabled:
            return
        self.use_cuda = use_cuda
        self.entered = False
        self.record_shapes = record_shapes
        self.with_flops = with_flops
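        # FLOPs estimation requires input shapes, so with_flops implies
        # record_shapes.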
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules

        if self.use_cuda and not torch.cuda.is_available():
            warnings.warn(
                "CUDA is not available, disabling CUDA profiling",
                stacklevel=2,
            )
            self.use_cuda = False

        if self.use_cuda:
            self.profiler_kind = ProfilerState.CUDA
        else:
            self.profiler_kind = ProfilerState.CPU

    def config(self):
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            # avoid exposing _ExperimentalConfig in the legacy public API
            torch._C._profiler._ExperimentalConfig(),
        )

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self.entered = True
        self._start_trace()
        return self

    def _start_trace(self):
        _enable_profiler_legacy(self.config())

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        if self.use_cuda:
            torch.cuda.synchronize()

        records = _disable_profiler_legacy()
        parsed_results = _parse_legacy_records(records)
        self.function_events = EventList(
            parsed_results,
            use_device="cuda" if self.use_cuda else None,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        self.function_events._build_tree()
        return False

    def __repr__(self):
        if self.function_events is None:
            return "<unfinished profiler_legacy.profile>"
        return repr(self.function_events)

    def __str__(self):
        if self.function_events is None:
120            return "<unfinished profile.profiler_legacy.profile>"
        return str(self.function_events)

    def _check_finish(self):
        if self.function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.export_chrome_trace(path)

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self.function_events.export_stacks(path, metric)

    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Return CPU time as the sum of self times across all events."""
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.self_cpu_time_total


def _parse_legacy_records(thread_records):
    def _get_record_key(record):
        """Return a tuple for correlating start and end records in `_parse_legacy_records`."""
        return (record.handle(), record.node_id())
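    # For example, a "push" and its matching "pop" with handle 7 emitted on
    # node 0 both map to the key (7, 0), which ties the range together.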

    start_record = None
    functions = []

    # '__start_profile' is not guaranteed to be first, so we must find it here
    for record in itertools.chain.from_iterable(thread_records):
        name = record.name()
        if start_record is None and name == "__start_profile":
            start_record = record

    assert start_record is not None and not start_record.is_remote()

    for thread_record_list in thread_records:
        # accumulated memory allocations per handle
        cpu_memory_allocs = {}
        cuda_memory_allocs = {}
        # ranges per handle
        range_starts = {}

        filtered_handles = set()
        prev_record = None
        for record in thread_record_list:
            record_key = _get_record_key(record)
            if _filter_name(record.name()) or record_key in filtered_handles:
                filtered_handles.add(record_key)
                continue

            if record.kind() == "push":
                # workaround to reduce double logging from operator
                # wrappers and redispatch
                if prev_record is not None:
                    duplicate = (
                        prev_record.name() == record.name()
                        and prev_record.kind() == record.kind()
                        and prev_record.node_id() == record.node_id()
                    )
                    if duplicate:
                        filtered_handles.add(record_key)
                        continue

                range_starts[record_key] = record
                cpu_memory_allocs[record_key] = 0
                cuda_memory_allocs[record_key] = 0
            elif record.kind() == "pop":
                assert (
                    record_key in range_starts
                ), f"""Expected record with key {record_key} to exist in range_starts.
                    This means that the pop event did not have a corresponding push."""

                start = range_starts[record_key]

                cpu_memory_usage = cpu_memory_allocs[record_key]
                cuda_memory_usage = cuda_memory_allocs[record_key]
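                # Treat the range as async if it was explicitly flagged or if
                # its push and pop were recorded on different threads.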
                is_async = start.is_async() or (start.thread_id() != record.thread_id())
                is_remote_event = record.is_remote()
                start_flops = start.flops()

                fe = FunctionEvent(
                    id=record.handle(),
                    node_id=record.node_id(),
                    name=_rewrite_name(name=start.name(), with_wildcard=True),
                    trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
                    thread=start.thread_id(),
                    start_us=start_record.cpu_elapsed_us(start),
                    end_us=start_record.cpu_elapsed_us(record),
                    fwd_thread=start.fwd_thread_id(),
                    input_shapes=start.shapes(),
                    stack=[
                        entry for entry in start.stack() if _filter_stack_entry(entry)
                    ],
                    scope=start.scope(),
                    use_device="cuda" if start.has_cuda() else None,
                    cpu_memory_usage=cpu_memory_usage,
                    device_memory_usage=cuda_memory_usage,
                    is_async=is_async,
                    is_remote=is_remote_event,
                    sequence_nr=start.sequence_nr(),
                    device_type=DeviceType.CPU,
                    is_legacy=True,
                    flops=start_flops,
                )
                # note: async events have only cpu total time
                if not is_async and start.has_cuda():
                    duration = start.cuda_elapsed_us(record)
                    if duration > 0:
                        fe.append_kernel(start.name(), start.device(), duration)
                functions.append(fe)
                del range_starts[record_key]
                del cpu_memory_allocs[record_key]
                del cuda_memory_allocs[record_key]
            elif record.kind() == "memory_alloc":
                num_open_handles_cpu = len(cpu_memory_allocs)
                num_open_handles_cuda = len(cuda_memory_allocs)
                assert num_open_handles_cpu == num_open_handles_cuda
                for handle in cpu_memory_allocs.keys():
                    cpu_memory_allocs[handle] += record.cpu_memory_usage()
                for handle in cuda_memory_allocs.keys():
                    cuda_memory_allocs[handle] += record.cuda_memory_usage()
                if num_open_handles_cpu == 0:
                    # output event as a top-level memory event
                    fe = FunctionEvent(
                        id=0,
                        name=MEMORY_EVENT_NAME,
                        trace_name=None,
                        thread=0,
                        start_us=0,
                        end_us=0,
                        stack=[],
                        cpu_memory_usage=record.cpu_memory_usage(),
                        device_memory_usage=record.cuda_memory_usage(),
                        is_legacy=True,
                    )
                    functions.append(fe)
            prev_record = record

    # Sort functions by start time ascending, then by end time descending.
    # When nested events share the same start time (which can happen due to
    # the granularity of the clock tick), this orders the outermost call
    # first and keeps the ordering of FunctionEvents stable.
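    # For example, an outer event spanning [0, 10] and an inner event
    # spanning [0, 5] produce keys [0, -10] < [0, -5], so the outer sorts first.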
    functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
    return functions