1# mypy: allow-untyped-defs 2import functools 3import operator 4import re 5from collections import deque 6from dataclasses import dataclass 7from typing import Dict, List, TYPE_CHECKING 8 9from torch.autograd.profiler import profile 10from torch.profiler import DeviceType 11 12 13if TYPE_CHECKING: 14 from torch.autograd import _KinetoEvent 15 16 17def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False): 18 order = reversed if reverse else lambda x: x 19 remaining = deque(order(tree)) 20 while remaining: 21 curr_event = next_fn(remaining) 22 yield curr_event 23 for child_event in order(children_fn(curr_event)): 24 remaining.append(child_event) 25 26 27traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True) 28traverse_bfs = functools.partial( 29 _traverse, next_fn=lambda x: x.popleft(), reverse=False 30) 31 32 33@dataclass 34class EventMetrics: 35 duration_time_ns: int = 0 36 self_time_ns: int = 0 37 idle_time_ns: int = 0 38 queue_depth: int = 0 39 40 @property 41 def fraction_idle_time(self): 42 if self.duration_time_ns == 0: 43 return 0.0 44 return self.idle_time_ns / self.duration_time_ns 45 46 47@dataclass 48class Interval: 49 start: int 50 end: int 51 queue_depth: int = 0 52 53 54class EventKey: 55 def __init__(self, event): 56 self.event = event 57 58 def __hash__(self): 59 return hash(self.event.id) 60 61 def __eq__(self, other): 62 return self.event.id == other.event.id 63 64 def __repr__(self): 65 return f"{self.event.name}" 66 67 def intervals_overlap(self, intervals: List[Interval]): 68 overlap_time = 0 69 intervals = sorted(intervals, key=lambda x: x.start) 70 71 if intervals: 72 overlap_start = max(self.event.start_time_ns, intervals[0].start) 73 overlap_end = min(self.event.end_time_ns, intervals[0].end) 74 75 if overlap_start < overlap_end: 76 overlap_time += overlap_end - overlap_start 77 78 i, j = 0, 1 79 while j < len(intervals): 80 prev_interval = intervals[i] 81 curr_interval = intervals[j] 82 j += 1 83 if prev_interval.end > curr_interval.start: 84 # Completely subsumed by previous interval 85 if prev_interval.end > curr_interval.end: 86 j += 1 87 continue 88 else: 89 curr_interval.start = prev_interval.end 90 i = j 91 92 overlap_start = max(self.event.start_time_ns, curr_interval.start) 93 overlap_end = min(self.event.end_time_ns, curr_interval.end) 94 if overlap_start < overlap_end: 95 overlap_time += overlap_end - overlap_start 96 97 return overlap_time 98 99 100class BasicEvaluation: 101 def __init__(self, prof: profile): 102 self.profile = prof 103 self.metrics: Dict[EventKey, EventMetrics] = {} 104 self.compute_self_time() 105 self.event_keys = sorted( 106 (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns 107 ) 108 self.events = [e.event for e in self.event_keys] 109 self.cuda_events: List[_KinetoEvent] = [] 110 self.queue_depth_list = self.compute_queue_depth() 111 self.compute_idle_time() 112 113 def compute_self_time(self): 114 """ 115 Computes event's self time(total time - time in child ops). 116 """ 117 assert self.profile.kineto_results is not None 118 stack = deque(self.profile.kineto_results.experimental_event_tree()) 119 120 # standard iterating dfs 121 while stack: 122 curr_event = stack.pop() 123 self_time = curr_event.duration_time_ns 124 for child_event in curr_event.children: 125 self_time -= child_event.duration_time_ns 126 stack.append(child_event) 127 assert ( 128 EventKey(curr_event) not in self.metrics 129 ), f"Duplicate id: {curr_event.id}, {curr_event.name}" 130 self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time) 131 self.metrics[ 132 EventKey(curr_event) 133 ].duration_time_ns = curr_event.duration_time_ns 134 135 def compute_queue_depth(self): 136 """ 137 Computes queue_depth at each event. This will calculate the queue depth data for 138 All the events in the tree. 139 This will return a list of Interval of queue depth data of cuda launch and kernels. 140 """ 141 assert self.profile.kineto_results is not None 142 cuda_event_list = self.profile.kineto_results.events() 143 144 def is_cuda_launch_kernel(e): 145 # TODO: find a better way to identify cudaLaunchKernel 146 return e.name == "cudaLaunchKernel" 147 148 def is_cuda_kernel(e): 149 # TODO: find a better way to identify CUDA Kernel 150 return e.device_type() == DeviceType.CUDA and "mem" not in e.name.lower() 151 152 cuda_launch_events = sorted( 153 (e for e in cuda_event_list if is_cuda_launch_kernel(e)), 154 key=lambda x: x.start_ns(), 155 ) 156 cuda_kernel_events = sorted( 157 (e for e in cuda_event_list if is_cuda_kernel(e)), 158 key=lambda x: x.start_ns(), 159 ) 160 161 self.cuda_events = sorted( 162 cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns() 163 ) 164 165 kernel_mapping: Dict[_KinetoEvent, int] = {} 166 last_mapped_kernel = 0 167 for cuda_launch_event in cuda_launch_events: 168 index = index_of_first_match( 169 cuda_kernel_events, 170 lambda x: x.linked_correlation_id() 171 == cuda_launch_event.linked_correlation_id(), 172 start=last_mapped_kernel, 173 ) 174 kernel_mapping[cuda_launch_event] = index 175 last_mapped_kernel = index if index is not None else last_mapped_kernel 176 177 current_kernel_index = 0 178 spawned_kernel_index = -1 179 180 all_events = cuda_launch_events + cuda_kernel_events + self.events 181 182 def new_old_event_comparator(event): 183 if hasattr(event, "start_us"): 184 return event.start_us() * 1000 185 if hasattr(event, "start_ns"): 186 return event.start_ns() 187 if hasattr(event, "start_time_ns"): 188 return event.start_time_ns 189 raise Exception("Unknown Event Type") # noqa: TRY002 190 191 queue_depth_list: List[Interval] = [] 192 all_events.sort(key=new_old_event_comparator) 193 for event in all_events: 194 # Find latest cuda kernel event 195 if hasattr(event, "start_us"): 196 start_time = event.start_us() * 1000 197 end_time = (event.start_us() + event.duration_us()) * 1000 198 # Find current spawned cuda kernel event 199 if event in kernel_mapping and kernel_mapping[event] is not None: 200 spawned_kernel_index = kernel_mapping[event] 201 if hasattr(event, "start_ns"): 202 start_time = event.start_ns() 203 end_time = event.start_ns() + event.duration_ns() 204 # Find current spawned cuda kernel event 205 if event in kernel_mapping and kernel_mapping[event] is not None: 206 spawned_kernel_index = kernel_mapping[event] 207 elif hasattr(event, "start_time_ns"): 208 start_time = event.start_time_ns # type: ignore[attr-defined] 209 end_time = event.end_time_ns # type: ignore[attr-defined] 210 211 while ( 212 current_kernel_index < len(cuda_kernel_events) 213 and (cuda_kernel_events[current_kernel_index].start_ns()) 214 <= start_time # type: ignore[possibly-undefined] 215 ): 216 current_kernel_index += 1 217 current_queue_depth = spawned_kernel_index - current_kernel_index + 1 218 current_queue_depth = max(current_queue_depth, 0) 219 220 if hasattr(event, "start_us") or hasattr(event, "start_ns"): 221 queue_depth_list.append( 222 Interval(start_time, end_time, current_queue_depth) # type: ignore[possibly-undefined] 223 ) 224 elif hasattr(event, "start_time_ns"): 225 self.metrics[EventKey(event)].queue_depth = current_queue_depth 226 227 return queue_depth_list 228 229 def compute_idle_time(self): 230 """ 231 Computes idle time of the profile. 232 """ 233 # Based on queue_depth_list, we can calculate idle time for all the events 234 idle = False 235 idle_start = 0 236 idle_intervals: List[Interval] = [] 237 if self.queue_depth_list and self.events: 238 idle_intervals += [ 239 Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start), 240 Interval(self.queue_depth_list[-1].end, self.events[-1].end_time_ns), 241 ] 242 243 for data_point in self.queue_depth_list: 244 if data_point.queue_depth == 0 and not idle: 245 idle_start = data_point.end 246 idle = True 247 if data_point.queue_depth > 0 and idle: 248 idle_intervals.append(Interval(idle_start, data_point.start)) 249 idle = False 250 251 event_list = [e.event for e in self.metrics.keys()] 252 for event in event_list: 253 self.metrics[EventKey(event)].idle_time_ns = EventKey( 254 event 255 ).intervals_overlap(idle_intervals) 256 257 def rank_events(self, length): 258 """ 259 Filter and Rank the events based on some heuristics: 260 1) Events that are in the falling phase of the queue depth. 261 2) Events that have a high idle_time, self_time difference. 262 263 Parameters: 264 length: The number of events to return. 265 """ 266 267 # Find the interval when qd is falling to 0 268 import torch 269 270 queue_depth_list = list(reversed(self.queue_depth_list)) 271 qd_values = [e.queue_depth for e in queue_depth_list] 272 273 bottom_threashold = 0 274 top_threashold = 4 275 decrease_interval = [] 276 i = 0 277 while i < len(qd_values): 278 if qd_values[i] > bottom_threashold: 279 i += 1 280 continue 281 for j in range(i + 1, len(qd_values)): 282 # Find next zero and if the max value between them exceeds 283 # the threshold, then we have a falling interval 284 next_minimum_idx = index_of_first_match( 285 qd_values, lambda x: x <= bottom_threashold, start=j 286 ) 287 peak_idx = argmax(qd_values, start=j, end=next_minimum_idx) 288 289 # if is a valid peak, we add to list and continue 290 if peak_idx is not None and qd_values[peak_idx] >= top_threashold: 291 decrease_interval.append( 292 Interval( 293 queue_depth_list[peak_idx].start, queue_depth_list[i].start 294 ) 295 ) 296 i = next_minimum_idx if next_minimum_idx is not None else i 297 break 298 i += 1 299 # Filter out events that are not in the decrease interval 300 event_list = [ 301 event 302 for event in self.metrics.keys() 303 if event.intervals_overlap(decrease_interval) 304 ] 305 if event_list: 306 self_time = torch.tensor( 307 [self.metrics[event].self_time_ns for event in event_list], 308 dtype=torch.float32, 309 ) 310 idle_time = torch.tensor( 311 [self.metrics[event].fraction_idle_time for event in event_list], 312 dtype=torch.float32, 313 ) 314 normalized_gain = (idle_time - torch.mean(idle_time)) / torch.std(idle_time) 315 normalized_self = (self_time - torch.mean(self_time)) / torch.std(self_time) 316 heuristic_score_list = normalized_gain + 0.6 * normalized_self 317 318 # Sort events by heuristic 319 event_list = [ 320 event 321 for _, event in sorted( 322 zip(heuristic_score_list, event_list), 323 key=operator.itemgetter(0), 324 reverse=True, 325 ) 326 ] 327 event_list = event_list[:length] 328 return event_list 329 330 def get_optimizable_events(self, length: int = 1, print_enable: bool = True): 331 event_list = self.rank_events(length) 332 if not print_enable: 333 return event_list 334 output = "Optimizable events:\n" if event_list else "No events to optimize\n" 335 336 output += "\n".join( 337 [ 338 f"""{'-'*80} 339Event: {event} 340Source code location: {source_code_location(event.event)} 341Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% 342{'-'*80}""" 343 for event in event_list 344 ] 345 ) 346 if print_enable: 347 print(output) 348 return event_list 349 350 351def index_of_first_match(seq, predicate, start=0, end=None): 352 if end is None or end >= len(seq): 353 end = len(seq) 354 for i in range(start, end): 355 if predicate(seq[i]): 356 return i 357 return None 358 359 360def argmax(seq, key=lambda x: x, start=0, end=None): 361 seq = seq[start:end] 362 if len(seq) == 0: 363 return None 364 return seq.index(max(seq, key=key)) + start 365 366 367def source_code_location(event): 368 while event is not None: 369 match = re.search(r"\.py\(.*\)", event.name) 370 if match is None: 371 event = event.parent 372 continue 373 return event.name 374 return "No source code location found" 375 376 377# Provide an OSS workaround for cudagraphs + CUPTI issue 378# https://github.com/pytorch/pytorch/issues/75504 379# TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when 380# we stop supporting older CUDA versions. 381def _init_for_cuda_graphs(): 382 from torch.autograd.profiler import profile 383 384 with profile(): 385 pass 386