xref: /aosp_15_r20/external/pytorch/torch/profiler/_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import operator
4import re
5from collections import deque
6from dataclasses import dataclass
7from typing import Dict, List, TYPE_CHECKING
8
9from torch.autograd.profiler import profile
10from torch.profiler import DeviceType
11
12
13if TYPE_CHECKING:
14    from torch.autograd import _KinetoEvent
15
16
def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False):
    """Generic lazy tree traversal.

    Args:
        tree: iterable of root events.
        next_fn: callable that removes and returns the next event from the
            working deque (``pop`` for DFS, ``popleft`` for BFS).
        children_fn: returns an event's children (defaults to ``.children``).
        reverse: when True, roots and children are enqueued in reversed
            order (needed so DFS-via-``pop`` yields left-to-right preorder).

    Yields:
        Events in the order determined by ``next_fn``/``reverse``.
    """

    def order(seq):
        return reversed(seq) if reverse else seq

    pending = deque(order(tree))
    while pending:
        node = next_fn(pending)
        yield node
        pending.extend(order(children_fn(node)))


# Preorder depth-first traversal: pop from the right, enqueue reversed.
traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True)
# Level-order breadth-first traversal: pop from the left, natural order.
traverse_bfs = functools.partial(
    _traverse, next_fn=lambda x: x.popleft(), reverse=False
)
31
32
@dataclass
class EventMetrics:
    """Aggregated per-event timing metrics; all durations are in nanoseconds."""

    duration_time_ns: int = 0
    self_time_ns: int = 0
    idle_time_ns: int = 0
    queue_depth: int = 0

    @property
    def fraction_idle_time(self):
        """Fraction of the event's duration spent idle (0.0 for zero duration)."""
        return (
            0.0
            if self.duration_time_ns == 0
            else self.idle_time_ns / self.duration_time_ns
        )
45
46
@dataclass
class Interval:
    """A time span (start/end in ns) tagged with the queue depth observed on it."""

    start: int
    end: int
    queue_depth: int = 0
52
53
class EventKey:
    """Hashable wrapper around a profiler event, identified by ``event.id``."""

    def __init__(self, event):
        self.event = event

    def __hash__(self):
        return hash(self.event.id)

    def __eq__(self, other):
        return self.event.id == other.event.id

    def __repr__(self):
        return f"{self.event.name}"

    def intervals_overlap(self, intervals: List["Interval"]):
        """Return the total time (ns) this event overlaps the union of ``intervals``.

        The intervals are swept in order of start time while tracking the right
        edge of the union already counted, so time shared by overlapping
        intervals is counted once.

        Fixes two defects in the previous implementation:
        - an extra ``j += 1`` on the "subsumed interval" branch skipped the
          following interval entirely, under-counting overlap;
        - ``curr_interval.start`` was assigned in place, mutating the caller's
          ``Interval`` objects (which are reused across events by
          ``compute_idle_time``).
        """
        event_start = self.event.start_time_ns
        event_end = self.event.end_time_ns
        overlap_time = 0
        covered_end = None  # right edge of the union processed so far
        for interval in sorted(intervals, key=lambda x: x.start):
            start = interval.start
            end = interval.end
            if covered_end is not None:
                if end <= covered_end:
                    # Completely subsumed by already-counted time.
                    continue
                # Clip off the portion already counted (local copy only;
                # never mutate the caller's Interval).
                start = max(start, covered_end)
            overlap_start = max(event_start, start)
            overlap_end = min(event_end, end)
            if overlap_start < overlap_end:
                overlap_time += overlap_end - overlap_start
            covered_end = end
        return overlap_time
98
99
class BasicEvaluation:
    """Derives per-event metrics (self time, queue depth, idle time) from a
    completed ``torch.autograd.profiler.profile`` and ranks events that look
    worth optimizing.

    All analyses run eagerly in ``__init__`` and populate ``self.metrics``
    (``EventKey -> EventMetrics``).
    """

    def __init__(self, prof: profile):
        self.profile = prof
        self.metrics: Dict[EventKey, EventMetrics] = {}
        # Seeds self.metrics with one EventMetrics per profiler tree event.
        self.compute_self_time()
        # Keys/events ordered by start time; later passes rely on this order.
        self.event_keys = sorted(
            (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns
        )
        self.events = [e.event for e in self.event_keys]
        self.cuda_events: List[_KinetoEvent] = []
        self.queue_depth_list = self.compute_queue_depth()
        self.compute_idle_time()

    def compute_self_time(self):
        """
        Computes each event's self time (total duration minus the time spent
        in child ops) and stores it in ``self.metrics``.
        """
        assert self.profile.kineto_results is not None
        stack = deque(self.profile.kineto_results.experimental_event_tree())

        # standard iterating dfs
        while stack:
            curr_event = stack.pop()
            self_time = curr_event.duration_time_ns
            for child_event in curr_event.children:
                self_time -= child_event.duration_time_ns
                stack.append(child_event)
            # Event ids are expected to be unique across the whole tree.
            assert (
                EventKey(curr_event) not in self.metrics
            ), f"Duplicate id: {curr_event.id}, {curr_event.name}"
            self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time)
            self.metrics[
                EventKey(curr_event)
            ].duration_time_ns = curr_event.duration_time_ns

    def compute_queue_depth(self):
        """
        Computes queue_depth at each event. This will calculate the queue depth data for
        All the events in the tree.
        This will return a list of Interval of queue depth data of cuda launch and kernels.
        """
        assert self.profile.kineto_results is not None
        cuda_event_list = self.profile.kineto_results.events()

        def is_cuda_launch_kernel(e):
            # TODO: find a better way to identify cudaLaunchKernel
            return e.name == "cudaLaunchKernel"

        def is_cuda_kernel(e):
            # TODO: find a better way to identify CUDA Kernel
            return e.device_type() == DeviceType.CUDA and "mem" not in e.name.lower()

        cuda_launch_events = sorted(
            (e for e in cuda_event_list if is_cuda_launch_kernel(e)),
            key=lambda x: x.start_ns(),
        )
        cuda_kernel_events = sorted(
            (e for e in cuda_event_list if is_cuda_kernel(e)),
            key=lambda x: x.start_ns(),
        )

        self.cuda_events = sorted(
            cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns()
        )

        # Map each launch event to the index of the kernel it spawned, matched
        # by correlation id. The search resumes from the last match because
        # both lists are sorted by start time.
        kernel_mapping: Dict[_KinetoEvent, int] = {}
        last_mapped_kernel = 0
        for cuda_launch_event in cuda_launch_events:
            index = index_of_first_match(
                cuda_kernel_events,
                lambda x: x.linked_correlation_id()
                == cuda_launch_event.linked_correlation_id(),
                start=last_mapped_kernel,
            )
            kernel_mapping[cuda_launch_event] = index
            last_mapped_kernel = index if index is not None else last_mapped_kernel

        current_kernel_index = 0
        spawned_kernel_index = -1

        all_events = cuda_launch_events + cuda_kernel_events + self.events

        def new_old_event_comparator(event):
            # Normalizes start times to ns across the three event flavors seen
            # here: legacy kineto events (start_us), newer kineto events
            # (start_ns) and profiler tree events (start_time_ns attribute).
            if hasattr(event, "start_us"):
                return event.start_us() * 1000
            if hasattr(event, "start_ns"):
                return event.start_ns()
            if hasattr(event, "start_time_ns"):
                return event.start_time_ns
            raise Exception("Unknown Event Type")  # noqa: TRY002

        queue_depth_list: List[Interval] = []
        all_events.sort(key=new_old_event_comparator)
        for event in all_events:
            # Find latest cuda kernel event
            if hasattr(event, "start_us"):
                start_time = event.start_us() * 1000
                end_time = (event.start_us() + event.duration_us()) * 1000
                # Find current spawned cuda kernel event
                if event in kernel_mapping and kernel_mapping[event] is not None:
                    spawned_kernel_index = kernel_mapping[event]
            # NOTE(review): this is "if", not "elif" -- an event exposing both
            # start_us and start_ns would be processed twice. Presumably the
            # two kineto APIs are mutually exclusive; confirm before changing.
            if hasattr(event, "start_ns"):
                start_time = event.start_ns()
                end_time = event.start_ns() + event.duration_ns()
                # Find current spawned cuda kernel event
                if event in kernel_mapping and kernel_mapping[event] is not None:
                    spawned_kernel_index = kernel_mapping[event]
            elif hasattr(event, "start_time_ns"):
                start_time = event.start_time_ns  # type: ignore[attr-defined]
                end_time = event.end_time_ns  # type: ignore[attr-defined]

            # Advance past every kernel that has started by start_time; the
            # queue depth is then (index of last spawned kernel) minus
            # (index of last started kernel), clamped to >= 0.
            while (
                current_kernel_index < len(cuda_kernel_events)
                and (cuda_kernel_events[current_kernel_index].start_ns())
                <= start_time  # type: ignore[possibly-undefined]
            ):
                current_kernel_index += 1
            current_queue_depth = spawned_kernel_index - current_kernel_index + 1
            current_queue_depth = max(current_queue_depth, 0)

            # Kineto (launch/kernel) events produce queue-depth intervals;
            # profiler tree events record the depth on their own metrics.
            if hasattr(event, "start_us") or hasattr(event, "start_ns"):
                queue_depth_list.append(
                    Interval(start_time, end_time, current_queue_depth)  # type: ignore[possibly-undefined]
                )
            elif hasattr(event, "start_time_ns"):
                self.metrics[EventKey(event)].queue_depth = current_queue_depth

        return queue_depth_list

    def compute_idle_time(self):
        """
        Computes idle time of the profile and stores it per event in
        ``self.metrics``.
        """
        # Based on queue_depth_list, we can calculate idle time for all the events
        idle = False
        idle_start = 0
        idle_intervals: List[Interval] = []
        # Time before the first and after the last queue-depth sample is
        # treated as idle as well.
        if self.queue_depth_list and self.events:
            idle_intervals += [
                Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start),
                Interval(self.queue_depth_list[-1].end, self.events[-1].end_time_ns),
            ]

        # An idle stretch runs from the end of a zero-depth sample to the
        # start of the next positive-depth sample.
        for data_point in self.queue_depth_list:
            if data_point.queue_depth == 0 and not idle:
                idle_start = data_point.end
                idle = True
            if data_point.queue_depth > 0 and idle:
                idle_intervals.append(Interval(idle_start, data_point.start))
                idle = False

        event_list = [e.event for e in self.metrics.keys()]
        for event in event_list:
            self.metrics[EventKey(event)].idle_time_ns = EventKey(
                event
            ).intervals_overlap(idle_intervals)

    def rank_events(self, length):
        """
        Filter and Rank the events based on some heuristics:
        1) Events that are in the falling phase of the queue depth.
        2) Events that have a high idle_time, self_time difference.

        Parameters:
            length: The number of events to return.
        """

        # Find the interval when qd is falling to 0
        import torch

        # Reversed copy: the scan below walks from the end of the trace
        # backwards, so a "falling" phase appears as zero -> peak here.
        queue_depth_list = list(reversed(self.queue_depth_list))
        qd_values = [e.queue_depth for e in queue_depth_list]

        # NOTE(review): "threashold" misspelling kept; renaming these locals
        # would be a cosmetic-only change.
        bottom_threashold = 0
        top_threashold = 4
        decrease_interval = []
        i = 0
        while i < len(qd_values):
            if qd_values[i] > bottom_threashold:
                i += 1
                continue
            for j in range(i + 1, len(qd_values)):
                # Find next zero and if the max value between them exceeds
                # the threshold, then we have a falling interval
                next_minimum_idx = index_of_first_match(
                    qd_values, lambda x: x <= bottom_threashold, start=j
                )
                peak_idx = argmax(qd_values, start=j, end=next_minimum_idx)

                # if is a valid peak, we add to list and continue
                if peak_idx is not None and qd_values[peak_idx] >= top_threashold:
                    decrease_interval.append(
                        Interval(
                            queue_depth_list[peak_idx].start, queue_depth_list[i].start
                        )
                    )
                    i = next_minimum_idx if next_minimum_idx is not None else i
                    break
            i += 1
        # Filter out events that are not in the decrease interval
        event_list = [
            event
            for event in self.metrics.keys()
            if event.intervals_overlap(decrease_interval)
        ]
        if event_list:
            self_time = torch.tensor(
                [self.metrics[event].self_time_ns for event in event_list],
                dtype=torch.float32,
            )
            idle_time = torch.tensor(
                [self.metrics[event].fraction_idle_time for event in event_list],
                dtype=torch.float32,
            )
            # Standard-score both signals, then combine with self time
            # weighted at 0.6 relative to idle-time fraction.
            normalized_gain = (idle_time - torch.mean(idle_time)) / torch.std(idle_time)
            normalized_self = (self_time - torch.mean(self_time)) / torch.std(self_time)
            heuristic_score_list = normalized_gain + 0.6 * normalized_self

            # Sort events by heuristic
            event_list = [
                event
                for _, event in sorted(
                    zip(heuristic_score_list, event_list),
                    key=operator.itemgetter(0),
                    reverse=True,
                )
            ]
            event_list = event_list[:length]
        return event_list

    def get_optimizable_events(self, length: int = 1, print_enable: bool = True):
        """Return the top ``length`` ranked events; when ``print_enable`` is
        True (the default), also print a human-readable report for them."""
        event_list = self.rank_events(length)
        if not print_enable:
            return event_list
        output = "Optimizable events:\n" if event_list else "No events to optimize\n"

        output += "\n".join(
            [
                f"""{'-'*80}
Event:                {event}
Source code location: {source_code_location(event.event)}
Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
{'-'*80}"""
                for event in event_list
            ]
        )
        if print_enable:
            print(output)
        return event_list
349
350
def index_of_first_match(seq, predicate, start=0, end=None):
    """Return the index of the first element of ``seq[start:end]`` satisfying
    ``predicate``, or None when no element matches.

    ``end`` of None (or past the end of ``seq``) means "to the end".
    """
    stop = len(seq) if end is None else min(end, len(seq))
    for idx in range(start, stop):
        if predicate(seq[idx]):
            return idx
    return None
358
359
def argmax(seq, key=lambda x: x, start=0, end=None):
    """Return the index (into the original ``seq``) of the element of
    ``seq[start:end]`` maximizing ``key``, or None if the window is empty.

    Ties resolve to the first occurrence of the maximal element within the
    window, matching ``list.index`` semantics.
    """
    window = seq[start:end]
    if not window:
        return None
    best = max(window, key=key)
    return start + window.index(best)
365
366
def source_code_location(event):
    """Walk up the event's parent chain and return the first event name that
    contains a Python source reference (".py(...)"); fall back to a fixed
    message when no ancestor matches."""
    node = event
    while node is not None:
        if re.search(r"\.py\(.*\)", node.name) is not None:
            return node.name
        node = node.parent
    return "No source code location found"
375
376
377# Provide an OSS workaround for cudagraphs + CUPTI issue
378# https://github.com/pytorch/pytorch/issues/75504
379# TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when
380# we stop supporting older CUDA versions.
381def _init_for_cuda_graphs():
382    from torch.autograd.profiler import profile
383
384    with profile():
385        pass
386