# mypy: allow-untyped-defs
"""Tools to visualize CUDA caching-allocator snapshots.

Produces flamegraphs (via Brendan Gregg's flamegraph.pl), textual segment
summaries, Pythonic trace listings, and self-contained HTML visualizations
from the snapshot dictionaries returned by ``torch.cuda.memory._snapshot()``.
"""
import pickle
import sys
import os
import io
import subprocess
import json
from functools import lru_cache
from typing import Any
from itertools import groupby
import base64
import warnings
import operator

cache = lru_cache(None)

__all__ = ["format_flamegraph", "segments", "memory", "compare"]

def _frame_fmt(f, full_filename=False):
    """Render one frame dict as ``filename:line:function``.

    Args:
        f: frame dict with 'filename', 'line', and 'name' keys.
        full_filename: keep the whole path instead of just the basename.
    """
    i = f['line']
    fname = f['filename']
    if not full_filename:
        fname = fname.split('/')[-1]
    func = f['name']
    return f'{fname}:{i}:{func}'

@cache
def _frame_filter(name, filename):
    """Return False for interpreter/dispatcher-internal frames that add only noise.

    Cached because the same (name, filename) pairs recur across many blocks.
    """
    omit_functions = [
        "unwind::unwind",
        "CapturedTraceback::gather",
        "gather_with_cpp",
        "_start",
        "__libc_start_main",
        "PyEval_",
        "PyObject_",
        "PyFunction_",
    ]
    omit_filenames = [
        "core/boxing",
        "/Register",
        "/Redispatch",
        "pythonrun.c",
        "Modules/main.c",
        "Objects/call.c",
        "Objects/methodobject.c",
        "pycore_ceval.h",
        "ceval.c",
        "cpython/abstract.h",
    ]
    for of in omit_functions:
        if of in name:
            return False
    for of in omit_filenames:
        if of in filename:
            return False
    return True

def _frames_fmt(frames, full_filename=False, reverse=False):
    """Format a list of frame dicts, dropping filtered frames.

    With ``reverse=True`` frames are emitted outermost-first, which is the
    order flamegraph.pl expects.
    """
    if reverse:
        frames = reversed(frames)
    return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])]

def _block_extra_legacy(b):
    # Old snapshot format: frames and the真 allocated ("real") size live in the
    # first history entry rather than on the block itself.
    if 'history' in b:
        frames = b['history'][0].get('frames', [])
        real_size = b['history'][0]['real_size']
    else:
        real_size = b.get('requested_size', b['size'])
        frames = []
    return frames, real_size

def _block_extra(b):
    """Return ``(frames, requested_size)`` for a block in either snapshot format."""
    if 'frames' not in b:
        # old snapshot format made it more complicated to get frames/allocated size
        return _block_extra_legacy(b)
    return b['frames'], b['requested_size']

def format_flamegraph(flamegraph_lines, flamegraph_script=None):
    """Pipe collapsed-stack lines through flamegraph.pl and return the SVG text.

    Args:
        flamegraph_lines: collapsed-stack input ("a;b;c <count>" per line).
        flamegraph_script: path to flamegraph.pl; downloaded to /tmp on first
            use when omitted.

    Raises:
        AssertionError: if flamegraph.pl exits with a non-zero status.
    """
    if flamegraph_script is None:
        flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
    if not os.path.exists(flamegraph_script):
        import urllib.request
        print(f"Downloading flamegraph.pl to: {flamegraph_script}")
        urllib.request.urlretrieve(
            'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
        subprocess.check_call(['chmod', '+x', flamegraph_script])
    args = [flamegraph_script, '--countname', 'bytes']
    p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8')
    assert p.stdin is not None
    assert p.stdout is not None
    p.stdin.write(flamegraph_lines)
    p.stdin.close()
    result = p.stdout.read()
    p.stdout.close()
    # Wait exactly once and check the exit status (the original called wait() twice).
    assert p.wait() == 0
    return result

def _write_blocks(f, prefix, blocks):
    """Emit one collapsed-stack line per block (or per history entry) to ``f``."""
    def frames_fragment(frames):
        if not frames:
            return "<non-python>"
        return ';'.join(_frames_fmt(frames, reverse=True))
    for b in blocks:
        if 'history' not in b:
            frames, accounted_for_size = _block_extra(b)
            f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n')
        else:
            # Legacy format: one line per history entry, plus a synthetic
            # <gaps> entry for any bytes the history does not account for.
            accounted_for_size = 0
            for h in b['history']:
                sz = h['real_size']
                accounted_for_size += sz
                if 'frames' in h:
                    frames = h['frames']
                    f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n')
                else:
                    f.write(f'{prefix};{b["state"]};<no-context> {sz}\n')
            gaps = b['size'] - accounted_for_size
            if gaps:
                f.write(f'{prefix};{b["state"]};<gaps> {gaps}\n')

def segments(snapshot, format_flamegraph=format_flamegraph):
    """Flamegraph of what memory is stored in each allocator segment."""
    f = io.StringIO()
    for seg in snapshot['segments']:
        prefix = f'stream_{seg["stream"]};seg_{seg["address"]}'
        _write_blocks(f, prefix, seg['blocks'])
    return format_flamegraph(f.getvalue())

def memory(snapshot, format_flamegraph=format_flamegraph):
    """Flamegraph of the program locations contributing to memory usage."""
    f = io.StringIO()
    for seg in snapshot['segments']:
        prefix = f'stream_{seg["stream"]}'
        _write_blocks(f, prefix, seg['blocks'])
    return format_flamegraph(f.getvalue())

def compare(before, after, format_flamegraph=format_flamegraph):
    """Flamegraph of segments that exist only in ``before`` or only in ``after``.

    Args:
        before, after: lists of segment dicts from two snapshots.
    """
    def _seg_key(seg):
        return (seg['address'], seg['total_size'])

    def _seg_info(seg):
        return f'stream_{seg["stream"]};seg_{seg["address"]}'

    f = io.StringIO()

    before_segs = {_seg_key(seg) for seg in before}
    after_segs = {_seg_key(seg) for seg in after}

    print(f'only_before = {[a for a, _ in (before_segs - after_segs)]}')
    print(f'only_after = {[a for a, _ in (after_segs - before_segs)]}')

    for seg in before:
        if _seg_key(seg) not in after_segs:
            _write_blocks(f, f'only_before;{_seg_info(seg)}', seg['blocks'])

    for seg in after:
        if _seg_key(seg) not in before_segs:
            _write_blocks(f, f'only_after;{_seg_info(seg)}', seg['blocks'])

    return format_flamegraph(f.getvalue())

def _format_size(num):
    """Human-readable byte size, e.g. 2048 -> '2.0KiB'."""
    # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
            return f"{num:3.1f}{unit}B"
        num /= 1024.0
    return f"{num:.1f}YiB"

class Bytes:
    """Integer byte count that reprs as a human-readable size."""
    def __init__(self, value):
        self.value = value

    def __add__(self, rhs):
        return Bytes(self.value + rhs)

    def __repr__(self):
        return _format_size(self.value)

def calc_active(seg):
    """Total bytes of active_allocated blocks in one segment."""
    return sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated')

def _report_free(free_external, free_internal):
    """Format total free bytes, annotating what fraction is internal rounding waste."""
    total = free_external + free_internal
    suffix = ''
    if total != 0:
        pct = (free_internal / total) * 100
        suffix = f' ({pct:.1f}% internal)'
    return f'{Bytes(total)}{suffix}'

PAGE_SIZE = 1024 * 1024 * 20
legend = f"""\

Legend:
    [a     ] - a segment in the allocator
     ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
     a-z: pages filled with a single block's content
     ' ': page is completely free
     *: page if completely full with multiple blocks
     0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
     (X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
"""

def segsum(data):
    r"""Visually reports how the allocator has filled its segments.

    This printout can help debug fragmentation issues since free fragments
    will appear as gaps in this printout. The amount of free space is reported
    for each segment.
    We distinguish between internal free memory which occurs because the
    allocator rounds the allocation size, and external free memory, which are
    the gaps between allocations in a segment.
    Args:
        data: snapshot dictionary created from _snapshot()
    """
    out = io.StringIO()
    out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
    total_reserved = 0
    total_allocated = 0
    free_external = 0
    free_internal = 0
    for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
        total_reserved += seg['total_size']

        seg_free_external = 0
        seg_free_internal = 0
        seg_allocated = 0
        all_ranges = []
        boffset = 0
        for b in seg['blocks']:
            active = b['state'] == 'active_allocated'
            if active:
                _, allocated_size = _block_extra(b)
                all_ranges.append((boffset, allocated_size, True))
                seg_allocated += allocated_size
                # rounding waste: block size minus what the caller asked for
                seg_free_internal += b['size'] - allocated_size
            else:
                seg_free_external += b['size']

            boffset += b['size']

        total_allocated += seg_allocated
        free_external += seg_free_external
        free_internal += seg_free_internal

        # One character cell per PAGE_SIZE page of this segment.
        nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
        occupied = [' ' for _ in range(nseg)]
        frac = [0.0 for _ in range(nseg)]
        for i, (start_, size, active) in enumerate(all_ranges):
            finish_ = (start_ + size)
            start = start_ // PAGE_SIZE
            finish = (finish_ - 1) // PAGE_SIZE + 1
            m = chr(ord('a' if active else 'A') + (i % 26))
            for j in range(start, finish):
                s = max(start_, j * PAGE_SIZE)
                e = min(finish_, (j + 1) * PAGE_SIZE)
                frac[j] += (e - s) / PAGE_SIZE
                if occupied[j] != ' ':
                    # page shared by multiple blocks: show fill percentage instead
                    occupied[j] = '0123456789*'[int(frac[j] * 10)]
                else:
                    occupied[j] = m
        body = ''.join(occupied)
        assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
        stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
        if seg['total_size'] >= PAGE_SIZE:
            out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
                      f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
    out.write(f'segments: {len(data["segments"])}\n')
    out.write(f'total_reserved: {Bytes(total_reserved)}\n')
    out.write(f'total_allocated: {Bytes(total_allocated)}\n')
    out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
    out.write(legend)
    assert free_internal + free_external + total_allocated == total_reserved
    return out.getvalue()

def trace(data):
    """Return a Pythonic text listing of the allocator trace events per device."""
    out = io.StringIO()

    def format(entries):
        segment_intervals : list = []
        segment_addr_to_name = {}
        # addr -> (short name, size, event index of the allocation)
        allocation_addr_to_name = {}

        free_names : list = []
        next_name = 0

        def _name():
            # Recycle freed short names before minting 'a'..'z', then 'a1'...
            nonlocal next_name
            if free_names:
                return free_names.pop()
            r, m = next_name // 26, next_name % 26
            next_name += 1
            return f'{chr(ord("a") + m)}{"" if r == 0 else r}'

        def find_segment(addr):
            # Prefer segments observed via segment_alloc events, then fall
            # back to the snapshot's segment list.
            for name, saddr, size in segment_intervals:
                if addr >= saddr and addr < saddr + size:
                    return name, saddr
            for i, seg in enumerate(data['segments']):
                saddr = seg['address']
                size = seg['allocated_size']
                if addr >= saddr and addr < saddr + size:
                    return f'seg_{i}', saddr
            return None, None

        # Running total of live allocated bytes. (The original reused `count`
        # as the enumerate() index, which clobbered the total every iteration.)
        count = 0
        out.write(f'{len(entries)} entries\n')

        for idx, e in enumerate(entries):
            if e['action'] == 'alloc':
                addr, size = e['addr'], e['size']
                n = _name()
                seg_name, seg_addr = find_segment(addr)
                if seg_name is None:
                    seg_name = "MEM"
                    offset = addr
                else:
                    offset = addr - seg_addr
                out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
                allocation_addr_to_name[addr] = (n, size, idx)
                count += size
            elif e['action'] == 'free_requested':
                addr, size = e['addr'], e['size']
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'del {name} # {Bytes(size)}\n')
            elif e['action'] == 'free_completed':
                addr, size = e['addr'], e['size']
                count -= size
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'# free completed for {name} {Bytes(size)}\n')
                # Recycle by address: the original tested `name` against this
                # address-keyed dict, so entries were never removed or reused.
                if addr in allocation_addr_to_name:
                    free_names.append(name)
                    del allocation_addr_to_name[addr]
            elif e['action'] == 'segment_alloc':
                addr, size = e['addr'], e['size']
                name = _name()
                out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
                segment_intervals.append((name, addr, size))
                segment_addr_to_name[addr] = name
            elif e['action'] == 'segment_free':
                addr, size = e['addr'], e['size']
                name = segment_addr_to_name.get(addr, addr)
                out.write(f'cudaFree({name}) # {Bytes(size)}\n')
                # Same address-vs-name fix as above; also drop the freed
                # interval so find_segment no longer matches it.
                if addr in segment_addr_to_name:
                    free_names.append(name)
                    del segment_addr_to_name[addr]
                    segment_intervals[:] = [s for s in segment_intervals if s[1] != addr]
            elif e['action'] == 'oom':
                size = e['size']
                free = e['device_free']
                out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
            else:
                out.write(f'{e}\n')
        # Newline added so the next device header starts on its own line.
        out.write(f"TOTAL MEM: {Bytes(count)}\n")
    for i, d in enumerate(data['device_traces']):
        if d:
            out.write(f'Device {i} ----------------\n')
            format(d)
    return out.getvalue()


_memory_viz_template = r"""
<!DOCTYPE html>
<html>
<head>
</head>
<body>
<script type="module">
import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
const local_files = $SNAPSHOT
add_local_files(local_files, $VIZ_KIND)
</script>
</body>
"""

def _format_viz(data, viz_kind, device):
    """Pickle ``data``, embed it base64-encoded in the HTML viewer template."""
    if device is not None:
        warnings.warn(
            'device argument is deprecated, plots now contain all device',
            FutureWarning,
            stacklevel=3,
        )
    buffer = pickle.dumps(data)
    # Pad to a multiple of 3 so base64 needs no '=' padding. (The original
    # added 3 extra NULs when already aligned; pickle stops at its STOP
    # opcode, so trailing NULs are ignored either way.)
    buffer += b'\x00' * (-len(buffer) % 3)
    # Encode the buffer with base64
    encoded_buffer = base64.b64encode(buffer).decode('utf-8')

    json_format = json.dumps([{"name": 'snapshot.pickle', "base64": encoded_buffer}])
    return _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind)) \
                               .replace('$SNAPSHOT', json_format)

def trace_plot(data, device=None, plot_segments=False):
    """Generate a visualization over time of the memory usage recorded by the trace as an html file.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
        plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
            Defaults to False.

    Returns:
        str: HTML of visualization
    """
    return _format_viz(data, 'Active Memory Timeline' if not plot_segments else 'Active Cached Memory Timeline', device)


def _profile_to_snapshot(profile):
    """Convert a kineto memory profile into the snapshot dict the viz tools consume."""
    import torch
    from torch.profiler._memory_profiler import Action, TensorKey
    from torch._C._profiler import _EventType
    memory_profile = profile._memory_profile()

    # Map each allocation to its chain of Python caller frames.
    allocation_stacks = {}
    for event in memory_profile._op_tree.sorted_nodes:
        if event.tag == _EventType.Allocation:
            parent = event.parent
            python_parents = []
            while parent:
                if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
                    python_parents.append(parent)
                parent = parent.parent
            key = TensorKey.from_allocation(event.extra_fields)

            # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
            # key will be None. I should add some way to identify these, I just haven't yet.
            if key and event.extra_fields.alloc_size > 0:
                allocation_stacks[key] = python_parents

    device_count = torch.cuda.device_count()
    # One synthetic segment/trace per CUDA device, plus one trailing slot for
    # non-CUDA allocations.
    snapshot = {
        'device_traces': [[] for _ in range(device_count + 1)],
        'segments': [{'device': device,
                      'address': None,
                      'total_size': 0,
                      'stream': 0,
                      'blocks': []} for device in range(device_count + 1)]
    }

    def to_device(device):
        if device.type == 'cuda':
            return device.index
        else:
            return device_count

    def allocate(size, tensor_key, version, during_trace=True):
        device = to_device(tensor_key.device)
        addr = tensor_key.storage.ptr

        seg = snapshot['segments'][device]  # type: ignore[index]
        if seg['address'] is None or seg['address'] > addr:
            seg['address'] = addr
        seg['total_size'] = max(seg['total_size'], addr + size)  # record max addr for now, we will make it the size later
        category = memory_profile._categories.get(tensor_key, version)
        category = category.name.lower() if category is not None else "unknown"
        stack = allocation_stacks.get(tensor_key, ())
        stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
        r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
        if during_trace:
            snapshot['device_traces'][device].append(r)  # type: ignore[index]
        return r

    def free(alloc, device):
        for e in ('free_requested', 'free_completed'):
            snapshot['device_traces'][device].append({'action': e,  # type: ignore[index]
                                                      'addr': alloc['addr'],
                                                      'size': alloc['size'],
                                                      'stream': 0,
                                                      'frames': alloc['frames']})

    kv_to_elem = {}

    # create the device trace
    for time, action, (tensor_key, version), size in memory_profile.timeline:
        if not isinstance(tensor_key, TensorKey):
            continue
        if action == Action.CREATE:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
        elif action == Action.DESTROY:
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
        elif action == Action.INCREMENT_VERSION:
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
            kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
        elif action == Action.PREEXISTING:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)

    # create the final snapshot state
    blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
                     for (tensor_key, version), event in kv_to_elem.items()]
    for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)):
        seg = snapshot['segments'][device]  # type: ignore[index]
        last_addr = seg['address']
        for _, addr, size, frames in blocks:
            if last_addr < addr:
                # fill the gap between consecutive live blocks
                seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
            seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
            last_addr = addr + size
        if last_addr < seg['total_size']:
            seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})

    snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]  # type: ignore[attr-defined]
    for seg in snapshot['segments']:  # type: ignore[attr-defined, name-defined, no-redef]
        # convert recorded max address into an actual size
        seg['total_size'] -= seg['address']
        if not seg['blocks']:
            seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})

    return snapshot

def profile_plot(profile, device=None):
    """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.

    Args:
        profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.

    Returns:
        str: HTML of visualization
    """
    snapshot = _profile_to_snapshot(profile)
    return _format_viz(snapshot, 'Active Memory Timeline', device)


def segment_plot(data: Any, device=None):
    """Generate an HTML visualization of allocator segment state over time."""
    return _format_viz(data, 'Allocator State History', device)

if __name__ == "__main__":
    import os.path
    thedir = os.path.realpath(os.path.dirname(__file__))
    if thedir in sys.path:
        # otherwise we find cuda/random.py as random...
        sys.path.remove(thedir)
    import argparse

    fn_name = 'torch.cuda.memory._snapshot()'
    pickled = f'pickled memory statistics from {fn_name}'
    parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')

    subparsers = parser.add_subparsers(dest='action')

    def _output(p):
        p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')

    description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.'
    stats_a = subparsers.add_parser('stats', description=description)
    stats_a.add_argument('input', help=pickled)

    description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
    trace_a = subparsers.add_parser('trace', description=description)
    trace_a.add_argument('input', help=pickled)

    description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
    segments_a = subparsers.add_parser('segments', description=description)
    segments_a.add_argument('input', help=pickled)
    _output(segments_a)

    description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
    memory_a = subparsers.add_parser('memory', description=description)
    memory_a.add_argument('input', help=pickled)
    _output(memory_a)

    description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
        'or removed between two different memory snapshots.'
    compare_a = subparsers.add_parser('compare', description=description)
    compare_a.add_argument('before', help=pickled)
    compare_a.add_argument('after', help=pickled)
    _output(compare_a)

    plots = (
        ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."),
        ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.")
    )
    for cmd, description in plots:
        trace_plot_a = subparsers.add_parser(cmd, description=description)
        trace_plot_a.add_argument('input', help=pickled)
        help = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
        trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help)
        help = 'path to save the visualization(default: output.html)'
        trace_plot_a.add_argument('-o', '--output', default='output.html', help=help)
        if cmd == "trace_plot":
            help = 'visualize change to segments rather than individual allocations'
            trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help)


    args = parser.parse_args()

    def _read(name):
        # '-' reads from stdin; otherwise open and *close* the file
        # (the original leaked the handle).
        if name == '-':
            data = pickle.load(sys.stdin.buffer)
        else:
            with open(name, 'rb') as f:
                data = pickle.load(f)
        if isinstance(data, list):  # segments only...
            data = {'segments': data, 'traces': []}
        return data

    def _write(name, data):
        with open(name, 'w') as f:
            f.write(data)

    if args.action == 'segments':
        data = _read(args.input)
        _write(args.output, segments(data))
    elif args.action == 'memory':
        data = _read(args.input)
        _write(args.output, memory(data))
    elif args.action == 'stats':
        data = _read(args.input)
        print(segsum(data))
    elif args.action == 'trace':
        data = _read(args.input)
        print(trace(data))
    elif args.action == 'compare':
        before = _read(args.before)
        after = _read(args.after)
        _write(args.output, compare(before, after))
    elif args.action == 'trace_plot':
        data = _read(args.input)
        _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))
    elif args.action == 'segment_plot':
        data = _read(args.input)
        _write(args.output, segment_plot(data, device=args.device))