# mypy: allow-untyped-defs
import itertools
import warnings
from typing_extensions import deprecated

import torch
import torch.cuda
from torch.autograd import (
    _disable_profiler_legacy,
    _enable_profiler_legacy,
    DeviceType,
    ProfilerConfig,
    ProfilerState,
)
from torch.autograd.profiler_util import (
    _filter_name,
    _filter_stack_entry,
    _rewrite_name,
    EventList,
    FunctionEvent,
    MEMORY_EVENT_NAME,
)


__all__ = ["profile"]


@deprecated(
    "`torch.autograd.profiler_legacy.profile` is deprecated and will be removed in a future release. "
    "Please use `torch.profiler` instead.",
    category=None,  # TODO: change to `FutureWarning`
)
class profile:
    """DEPRECATED: use torch.profiler instead."""

    def __init__(
        self,
        enabled=True,
        *,
        use_cuda=False,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
    ):
        self.enabled: bool = enabled
        if not self.enabled:
            return
        self.use_cuda = use_cuda
        self.function_events = None
        self.entered = False
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules

        if self.use_cuda and not torch.cuda.is_available():
            warnings.warn(
                "CUDA is not available, disabling CUDA profiling",
                stacklevel=2,
            )
            self.use_cuda = False

        if self.use_cuda:
            self.profiler_kind = ProfilerState.CUDA
        else:
            self.profiler_kind = ProfilerState.CPU

    def config(self):
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            # avoid exposing _ExperimentalConfig in the legacy public API
            torch._C._profiler._ExperimentalConfig(),
        )

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self.entered = True
        self._start_trace()
        return self

    def _start_trace(self):
        _enable_profiler_legacy(self.config())

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        if self.use_cuda:
            torch.cuda.synchronize()

        records = _disable_profiler_legacy()
        parsed_results = _parse_legacy_records(records)
        self.function_events = EventList(
            parsed_results,
            use_device="cuda" if self.use_cuda else None,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        self.function_events._build_tree()
        return False

    def __repr__(self):
        if self.function_events is None:
            return "<unfinished profiler_legacy.profile>"
        return repr(self.function_events)

    def __str__(self):
        if self.function_events is None:
            return "<unfinished profiler_legacy.profile>"
        return str(self.function_events)

    def _check_finish(self):
        if self.function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.export_chrome_trace(path)

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self.function_events.export_stacks(path, metric)

    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Return CPU time as the sum of self times across all events."""
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.self_cpu_time_total
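
# Illustrative usage sketch (not part of this module's API; the tensors and
# variable names are placeholders). The legacy profiler is used as a context
# manager, and the collected events are queried after the block exits. New
# code should prefer the torch.profiler API shown second.
#
#     import torch
#     from torch.autograd import profiler_legacy
#
#     x = torch.randn(64, 64)
#     with profiler_legacy.profile(record_shapes=True) as prof:
#         y = torch.mm(x, x)
#     print(prof.table(sort_by="self_cpu_time_total", row_limit=10))
#
#     # Preferred replacement:
#     from torch.profiler import ProfilerActivity, profile as kineto_profile
#     with kineto_profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
#         y = torch.mm(x, x)
#     print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))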


def _parse_legacy_records(thread_records):
    def _get_record_key(record):
        """Return a tuple for correlating start and end records in `_parse_legacy_records`."""
        return (record.handle(), record.node_id())

    next_id = 0
    start_record = None
    functions = []
    record_stack = []

    # '__start_profile' is not guaranteed to be first, so we must find it here
    for record in itertools.chain.from_iterable(thread_records):
        name = record.name()
        if start_record is None and name == "__start_profile":
            start_record = record

    assert start_record is not None and not start_record.is_remote()
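
    # Each per-thread record list is scanned once: a "push" record opens an
    # interval keyed by (handle, node_id), the matching "pop" record closes it
    # and becomes a FunctionEvent, and "memory_alloc" records are credited to
    # every interval currently open on the thread (or emitted as a standalone
    # top-level memory event when none is open).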
    for thread_record_list in thread_records:
        # accumulated memory allocations per handle
        cpu_memory_allocs = {}
        cuda_memory_allocs = {}
        # ranges per handle
        range_starts = {}

        filtered_handles = set()
        prev_record = None
        for record in thread_record_list:
            record_key = _get_record_key(record)
            if _filter_name(record.name()) or record_key in filtered_handles:
                filtered_handles.add(record_key)
                continue

            if record.kind() == "push":
                # workaround to reduce double logging from operator
                # wrappers and redispatch
                if prev_record is not None:
                    duplicate = (
                        prev_record.name() == record.name()
                        and prev_record.kind() == record.kind()
                        and prev_record.node_id() == record.node_id()
                    )
                    if duplicate:
                        filtered_handles.add(record_key)
                        continue

                range_starts[record_key] = record
                cpu_memory_allocs[record_key] = 0
                cuda_memory_allocs[record_key] = 0
            elif record.kind() == "pop":
                assert (
                    record_key in range_starts
                ), f"""Expected record with key {record_key} to exist in range_starts.
                    This means that the pop event did not have a corresponding push."""

                start = range_starts[record_key]

                cpu_memory_usage = cpu_memory_allocs[record_key]
                cuda_memory_usage = cuda_memory_allocs[record_key]
                is_async = start.is_async() or (start.thread_id() != record.thread_id())
                is_remote_event = record.is_remote()
                start_flops = start.flops()

                fe = FunctionEvent(
                    id=record.handle(),
                    node_id=record.node_id(),
                    name=_rewrite_name(name=start.name(), with_wildcard=True),
                    trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
                    thread=start.thread_id(),
                    start_us=start_record.cpu_elapsed_us(start),
                    end_us=start_record.cpu_elapsed_us(record),
                    fwd_thread=start.fwd_thread_id(),
                    input_shapes=start.shapes(),
                    stack=[
                        entry for entry in start.stack() if _filter_stack_entry(entry)
                    ],
                    scope=start.scope(),
                    use_device="cuda" if start.has_cuda() else None,
                    cpu_memory_usage=cpu_memory_usage,
                    device_memory_usage=cuda_memory_usage,
                    is_async=is_async,
                    is_remote=is_remote_event,
                    sequence_nr=start.sequence_nr(),
                    device_type=DeviceType.CPU,
                    is_legacy=True,
                    flops=start_flops,
                )
                # note: async events have only cpu total time
                if not is_async and start.has_cuda():
                    duration = start.cuda_elapsed_us(record)
                    if duration > 0:
                        fe.append_kernel(start.name(), start.device(), duration)
                functions.append(fe)
                del range_starts[record_key]
                del cpu_memory_allocs[record_key]
                del cuda_memory_allocs[record_key]
            elif record.kind() == "memory_alloc":
                num_open_handles_cpu = len(cpu_memory_allocs)
                num_open_handles_cuda = len(cuda_memory_allocs)
                assert num_open_handles_cpu == num_open_handles_cuda
                for handle in cpu_memory_allocs.keys():
                    cpu_memory_allocs[handle] += record.cpu_memory_usage()
                for handle in cuda_memory_allocs.keys():
                    cuda_memory_allocs[handle] += record.cuda_memory_usage()
                if num_open_handles_cpu == 0:
                    # output event as a top-level memory event
                    fe = FunctionEvent(
                        id=0,
                        name=MEMORY_EVENT_NAME,
                        trace_name=None,
                        thread=0,
                        start_us=0,
                        end_us=0,
                        stack=[],
                        cpu_memory_usage=record.cpu_memory_usage(),
                        device_memory_usage=record.cuda_memory_usage(),
                        is_legacy=True,
                    )
                    functions.append(fe)
            prev_record = record

    # Sort functions by increasing start time, breaking ties by decreasing end
    # time. This ensures that, in the case of nested events which have the same
    # start time (which may happen due to the granularity of the given clock
    # tick), we always show the outermost nested call first. This adds
    # stability in how FunctionEvents appear.
    functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
    return functions