xref: /aosp_15_r20/external/pytorch/torch/cuda/_memory_viz.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import pickle
3import sys
4import os
5import io
6import subprocess
7import json
8from functools import lru_cache
9from typing import Any
10from itertools import groupby
11import base64
12import warnings
13import operator
14
# Unbounded memoizer (equivalent to functools.cache); used on _frame_filter below.
cache = lru_cache(None)

# Public API of this module; everything else is an implementation detail.
__all__ = ["format_flamegraph", "segments", "memory", "compare"]
18
19def _frame_fmt(f, full_filename=False):
20    i = f['line']
21    fname = f['filename']
22    if not full_filename:
23        fname = fname.split('/')[-1]
24    func = f['name']
25    return f'{fname}:{i}:{func}'
26
27@cache
28def _frame_filter(name, filename):
29    omit_functions = [
30        "unwind::unwind",
31        "CapturedTraceback::gather",
32        "gather_with_cpp",
33        "_start",
34        "__libc_start_main",
35        "PyEval_",
36        "PyObject_",
37        "PyFunction_",
38    ]
39    omit_filenames = [
40        "core/boxing",
41        "/Register",
42        "/Redispatch",
43        "pythonrun.c",
44        "Modules/main.c",
45        "Objects/call.c",
46        "Objects/methodobject.c",
47        "pycore_ceval.h",
48        "ceval.c",
49        "cpython/abstract.h",
50    ]
51    for of in omit_functions:
52        if of in name:
53            return False
54    for of in omit_filenames:
55        if of in filename:
56            return False
57    return True
58
def _frames_fmt(frames, full_filename=False, reverse=False):
    """Format a list of frame dicts, dropping plumbing frames via _frame_filter."""
    ordered = reversed(frames) if reverse else frames
    return [
        _frame_fmt(frame, full_filename)
        for frame in ordered
        if _frame_filter(frame['name'], frame['filename'])
    ]
63
64def _block_extra_legacy(b):
65    if 'history' in b:
66        frames = b['history'][0].get('frames', [])
67        real_size = b['history'][0]['real_size']
68    else:
69        real_size = b.get('requested_size', b['size'])
70        frames = []
71    return frames, real_size
72
73def _block_extra(b):
74    if 'frames' not in b:
75        # old snapshot format made it more complicated to get frames/allocated size
76        return _block_extra_legacy(b)
77    return b['frames'], b['requested_size']
78
def format_flamegraph(flamegraph_lines, flamegraph_script=None):
    """Render flamegraph input *flamegraph_lines* to SVG text.

    Uses Brendan Gregg's flamegraph.pl; if *flamegraph_script* is not given,
    the script is cached (and downloaded on first use) under /tmp, keyed by
    the current uid.

    Args:
        flamegraph_lines: newline-separated "stack count" lines, the format
            flamegraph.pl consumes.
        flamegraph_script: optional path to an existing flamegraph.pl.

    Returns:
        str: the SVG produced by flamegraph.pl.

    Raises:
        RuntimeError: if flamegraph.pl exits with a non-zero status.
    """
    if flamegraph_script is None:
        flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
    if not os.path.exists(flamegraph_script):
        import urllib.request
        print(f"Downloading flamegraph.pl to: {flamegraph_script}")
        urllib.request.urlretrieve(
            'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
        subprocess.check_call(['chmod', '+x', flamegraph_script])
    args = [flamegraph_script, '--countname', 'bytes']
    p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8')
    assert p.stdin is not None
    assert p.stdout is not None
    p.stdin.write(flamegraph_lines)
    p.stdin.close()
    result = p.stdout.read()
    p.stdout.close()
    # Wait exactly once (the original called p.wait() twice) and fail loudly;
    # the original used `assert`, which python -O strips.
    retcode = p.wait()
    if retcode != 0:
        raise RuntimeError(f'{flamegraph_script} exited with status {retcode}')
    return result
99
def _write_blocks(f, prefix, blocks):
    """Write one flamegraph line per allocation in *blocks* to file-like *f*.

    Lines have the form '<prefix>;<state>;<stack> <bytes>'. Any bytes of a
    block not accounted for by its allocations are emitted as '<gaps>'.
    """
    def frames_fragment(frames):
        # semicolon-separated, root-first, as flamegraph.pl expects
        return ';'.join(_frames_fmt(frames, reverse=True)) if frames else "<non-python>"

    for block in blocks:
        state = block["state"]
        if 'history' not in block:
            frames, accounted = _block_extra(block)
            f.write(f'{prefix};{state};{frames_fragment(frames)} {accounted}\n')
        else:
            # old snapshot format: one line per history event
            accounted = 0
            for event in block['history']:
                sz = event['real_size']
                accounted += sz
                if 'frames' in event:
                    f.write(f'{prefix};{state};{frames_fragment(event["frames"])} {sz}\n')
                else:
                    f.write(f'{prefix};{state};<no-context> {sz}\n')
        gap = block['size'] - accounted
        if gap:
            f.write(f'{prefix};{state};<gaps> {gap}\n')
122
def segments(snapshot, format_flamegraph=format_flamegraph):
    """Flamegraph of allocations grouped by stream and segment address."""
    buf = io.StringIO()
    for seg in snapshot['segments']:
        _write_blocks(buf, f'stream_{seg["stream"]};seg_{seg["address"]}', seg['blocks'])
    return format_flamegraph(buf.getvalue())
129
def memory(snapshot, format_flamegraph=format_flamegraph):
    """Flamegraph of allocations grouped by stream only (program locations)."""
    buf = io.StringIO()
    for seg in snapshot['segments']:
        _write_blocks(buf, f'stream_{seg["stream"]}', seg['blocks'])
    return format_flamegraph(buf.getvalue())
136
def compare(before, after, format_flamegraph=format_flamegraph):
    """Flamegraph of segments present only in *before* or only in *after*.

    *before* and *after* are lists of segment dicts; segments are matched by
    (address, total_size). Also prints the only-before/only-after addresses.
    """
    def seg_key(seg):
        return (seg['address'], seg['total_size'])

    def seg_info(seg):
        return f'stream_{seg["stream"]};seg_{seg["address"]}'

    buf = io.StringIO()

    before_keys = {seg_key(seg) for seg in before}
    after_keys = {seg_key(seg) for seg in after}

    print(f'only_before = {[a for a, _ in (before_keys - after_keys)]}')
    print(f'only_after = {[a for a, _ in (after_keys - before_keys)]}')

    for seg in before:
        if seg_key(seg) not in after_keys:
            _write_blocks(buf, f'only_before;{seg_info(seg)}', seg['blocks'])

    for seg in after:
        if seg_key(seg) not in before_keys:
            _write_blocks(buf, f'only_after;{seg_info(seg)}', seg['blocks'])

    return format_flamegraph(buf.getvalue())
161
162def _format_size(num):
163    # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
164    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
165        if abs(num) < 1024.0:
166            return f"{num:3.1f}{unit}B"
167        num /= 1024.0
168    return f"{num:.1f}YiB"
169
class Bytes:
    """Wraps an integer byte count so that repr() prints it human-readably."""

    def __init__(self, value):
        # raw number of bytes
        self.value = value

    def __repr__(self):
        return _format_size(self.value)

    def __add__(self, rhs):
        # adding a plain number yields a new Bytes wrapper
        return Bytes(self.value + rhs)
179
def calc_active(seg):
    """Total bytes in *seg* currently held by active allocations."""
    total = 0
    for block in seg['blocks']:
        if block['state'] == 'active_allocated':
            total += block['size']
    return total
182
def _report_free(free_external, free_internal):
    """Format total free bytes, noting what share is internal rounding waste."""
    total = free_external + free_internal
    if total == 0:
        return f'{Bytes(total)}'
    pct = (free_internal / total) * 100
    return f'{Bytes(total)} ({pct:.1f}% internal)'
190
# Granularity of the segsum() printout: each character represents one
# 20 MiB "page" of a segment.
PAGE_SIZE = 1024 * 1024 * 20
# Explanatory footer appended to segsum() output; interpolates PAGE_SIZE
# through Bytes so it renders human-readably.
legend = f"""\

Legend:
    [a     ] - a segment in the allocator
     ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
    a-z: pages filled with a single block's content
    ' ': page is completely free
    *: page if completely full with multiple blocks
    0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
    (X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
"""
203
def segsum(data):
    r"""Visually reports how the allocator has filled its segments.

    This printout can help debug fragmentation issues since free fragments
    will appear as gaps in this printout.  The amount of free space is reported
    for each segment.
    We distinguish between internal free memory which occurs because the
    allocator rounds the allocation size, and external free memory, which are
    the gaps between allocations in a segment.
    Args:
        data: snapshot dictionary created from _snapshot()
    Returns:
        str: the formatted report, ending with the legend.
    """
    # (Cleaned up vs. the original: removed an unused `segments` list that
    # shadowed the module-level segments() function, a dead `stream`
    # assignment that was immediately overwritten, and the unused
    # `active_size` and `internal_external` locals.)
    out = io.StringIO()
    out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
    total_reserved = 0
    total_allocated = 0
    free_external = 0
    free_internal = 0
    # smallest / least-active segments first
    for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
        total_reserved += seg['total_size']

        seg_free_external = 0
        seg_free_internal = 0
        seg_allocated = 0
        all_ranges = []
        boffset = 0
        for b in seg['blocks']:
            active = b['state'] == 'active_allocated'
            if active:
                _, allocated_size = _block_extra(b)
                all_ranges.append((boffset, allocated_size, True))
                seg_allocated += allocated_size
                # rounding waste: the block is larger than what was requested
                seg_free_internal += b['size'] - allocated_size
            else:
                seg_free_external += b['size']

            boffset += b['size']

        total_allocated += seg_allocated
        free_external += seg_free_external
        free_internal += seg_free_internal

        # one character per PAGE_SIZE page of this segment
        nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
        occupied = [' ' for _ in range(nseg)]
        frac = [0.0 for _ in range(nseg)]
        for i, (start_, size, active) in enumerate(all_ranges):
            finish_ = (start_ + size)
            start = start_ // PAGE_SIZE
            finish = (finish_ - 1) // PAGE_SIZE + 1
            m = chr(ord('a' if active else 'A') + (i % 26))
            for j in range(start, finish):
                s = max(start_, j * PAGE_SIZE)
                e = min(finish_, (j + 1) * PAGE_SIZE)
                frac[j] += (e - s) / PAGE_SIZE
                if occupied[j] != ' ':
                    # page shared by several blocks: show how full it is
                    # (clamped so float rounding can't index past '*')
                    occupied[j] = '0123456789*'[min(int(frac[j] * 10), 10)]
                else:
                    occupied[j] = m
        body = ''.join(occupied)
        assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
        stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
        if seg['total_size'] >= PAGE_SIZE:
            out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
                      f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
    out.write(f'segments: {len(data["segments"])}\n')
    out.write(f'total_reserved: {Bytes(total_reserved)}\n')
    out.write(f'total_allocated: {Bytes(total_allocated)}\n')
    out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
    out.write(legend)
    assert free_internal + free_external + total_allocated == total_reserved
    return out.getvalue()
280
def trace(data):
    """Render the recorded allocator trace as pseudo-Python alloc/free
    statements, one section per device that has trace entries.

    Args:
        data: snapshot dictionary created from _snapshot(), with
            'device_traces' and 'segments' keys.

    Returns:
        str: the formatted trace.
    """
    out = io.StringIO()

    def format(entries):
        segment_intervals : list = []
        segment_addr_to_name = {}
        allocation_addr_to_name = {}

        free_names : list = []
        next_name = 0

        def _name():
            # Short names a..z, then a1..z1, a2... — recycling freed names first.
            nonlocal next_name
            if free_names:
                return free_names.pop()
            r, m = next_name // 26, next_name % 26
            next_name += 1
            return f'{chr(ord("a") + m)}{"" if r == 0 else r}'

        def find_segment(addr):
            # Prefer segments allocated earlier in this trace; fall back to
            # the snapshot's segment list.
            for name, saddr, size in segment_intervals:
                if addr >= saddr and addr < saddr + size:
                    return name, saddr
            for i, seg in enumerate(data['segments']):
                saddr = seg['address']
                size = seg['allocated_size']
                if addr >= saddr and addr < saddr + size:
                    return f'seg_{i}', saddr
            return None, None

        out.write(f'{len(entries)} entries\n')

        # Bytes currently allocated. Kept separate from the enumerate() index:
        # the original reused `count` for both, so the TOTAL MEM line reported
        # a meaningless mix of loop index and sizes.
        mem_used = 0
        for count, e in enumerate(entries):
            if e['action'] == 'alloc':
                addr, size = e['addr'], e['size']
                n = _name()
                seg_name, seg_addr = find_segment(addr)
                if seg_name is None:
                    seg_name = "MEM"
                    offset = addr
                else:
                    offset = addr - seg_addr
                out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
                allocation_addr_to_name[addr] = (n, size, count)
                mem_used += size
            elif e['action'] == 'free_requested':
                addr, size = e['addr'], e['size']
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'del {name} # {Bytes(size)}\n')
            elif e['action'] == 'free_completed':
                addr, size = e['addr'], e['size']
                mem_used -= size
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'# free completed for {name} {Bytes(size)}\n')
                if addr in allocation_addr_to_name:
                    # Recycle the short name. (The original tested/deleted by
                    # *name* in this addr-keyed dict, so names were never
                    # recycled and entries leaked.)
                    free_names.append(name)
                    del allocation_addr_to_name[addr]
            elif e['action'] == 'segment_alloc':
                addr, size = e['addr'], e['size']
                name = _name()
                out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
                segment_intervals.append((name, addr, size))
                segment_addr_to_name[addr] = name
            elif e['action'] == 'segment_free':
                addr, size = e['addr'], e['size']
                name = segment_addr_to_name.get(addr, addr)
                out.write(f'cudaFree({name}) # {Bytes(size)}\n')
                if addr in segment_addr_to_name:
                    # same addr-vs-name keying fix as for allocations above
                    free_names.append(name)
                    del segment_addr_to_name[addr]
                # NOTE(review): the freed interval stays in segment_intervals,
                # matching the original behavior — confirm whether that is
                # intentional before "fixing" it.
            elif e['action'] == 'oom':
                size = e['size']
                free = e['device_free']
                out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
            else:
                out.write(f'{e}\n')
        out.write(f"TOTAL MEM: {Bytes(mem_used)}")

    for i, d in enumerate(data['device_traces']):
        if d:
            out.write(f'Device {i} ----------------\n')
            format(d)
    return out.getvalue()
368
369
# Minimal HTML shell for the interactive viewers: loads MemoryViz.js from a
# CDN and hands it the embedded snapshot.  The $SNAPSHOT and $VIZ_KIND
# placeholders are textually substituted by _format_viz().
_memory_viz_template = r"""
<!DOCTYPE html>
<html>
<head>
</head>
<body>
<script type="module">
import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
const local_files = $SNAPSHOT
add_local_files(local_files, $VIZ_KIND)
</script>
</body>
"""
383
def _format_viz(data, viz_kind, device):
    """Pickle *data*, base64-encode it, and splice it into the HTML template.

    *device* is deprecated: the generated plot always contains all devices,
    so a non-None value only triggers a FutureWarning.
    """
    if device is not None:
        warnings.warn(
            'device argument is deprecated, plots now contain all device',
            FutureWarning,
            stacklevel=3,
        )
    payload = pickle.dumps(data)
    # pad to a multiple of 3 bytes so the base64 output carries no '=' padding
    payload += b'\x00' * (3 - len(payload) % 3)
    encoded = base64.b64encode(payload).decode('utf-8')

    files_json = json.dumps([{"name": 'snapshot.pickle', "base64": encoded}])
    html = _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind))
    return html.replace('$SNAPSHOT', files_json)
399
def trace_plot(data, device=None, plot_segments=False):
    """Generate a visualization over time of the memory usage recorded by the trace as an html file.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
        plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
                                        Defaults to False.

    Returns:
        str: HTML of visualization
    """
    kind = 'Active Cached Memory Timeline' if plot_segments else 'Active Memory Timeline'
    return _format_viz(data, kind, device)
413
414
def _profile_to_snapshot(profile):
    """Convert a kineto memory profile into a _snapshot()-style dict.

    Builds one synthetic 'segment' per CUDA device (plus one trailing slot
    for non-CUDA allocations) and replays the profiler's timeline as a
    device trace of alloc/free events, so the result can be fed to the same
    viewers as a real allocator snapshot.
    """
    import torch
    from torch.profiler._memory_profiler import Action, TensorKey
    from torch._C._profiler import _EventType
    memory_profile = profile._memory_profile()

    # Map each allocation's TensorKey to its chain of Python caller frames,
    # gathered by walking parents of Allocation events in the op tree.
    allocation_stacks = {}
    for event in memory_profile._op_tree.sorted_nodes:
        if event.tag == _EventType.Allocation:
            parent = event.parent
            python_parents = []
            while parent:
                if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
                    python_parents.append(parent)
                parent = parent.parent
            key = TensorKey.from_allocation(event.extra_fields)

            # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
            #              key will be None. I should add some way to identify these, I just haven't yet.
            if key and event.extra_fields.alloc_size > 0:
                allocation_stacks[key] = python_parents


    device_count = torch.cuda.device_count()
    # index device_count is the catch-all slot for non-CUDA allocations
    snapshot = {
        'device_traces': [[] for _ in range(device_count + 1)],
        'segments': [{'device': device,
                      'address': None,
                      'total_size': 0,
                      'stream': 0,
                      'blocks': []} for device in range(device_count + 1)]
    }

    def to_device(device):
        # map a torch.device to its snapshot slot; non-CUDA goes to the extra slot
        if device.type == 'cuda':
            return device.index
        else:
            return device_count

    def allocate(size, tensor_key, version, during_trace=True):
        # Record an alloc event and grow the synthetic segment to cover it.
        device = to_device(tensor_key.device)
        addr = tensor_key.storage.ptr

        seg = snapshot['segments'][device]  # type: ignore[index]
        if seg['address'] is None or seg['address'] > addr:
            seg['address'] = addr
        seg['total_size'] = max(seg['total_size'], addr + size)  # record max addr for now, we will make it the size later
        category = memory_profile._categories.get(tensor_key, version)
        category = category.name.lower() if category is not None else "unknown"
        stack = allocation_stacks.get(tensor_key, ())
        # synthesize frame dicts; only the function name is known here
        stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
        r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
        if during_trace:
            snapshot['device_traces'][device].append(r)  # type: ignore[index]
        return r

    def free(alloc, device):
        # Emit the free_requested/free_completed pair the viewers expect.
        for e in ('free_requested', 'free_completed'):
            snapshot['device_traces'][device].append({'action': e,  # type: ignore[index]
                                                      'addr': alloc['addr'],
                                                      'size': alloc['size'],
                                                      'stream': 0,
                                                      'frames': alloc['frames']})

    # (tensor_key, version) -> the alloc event currently live for it
    kv_to_elem = {}



    # create the device trace
    for time, action, (tensor_key, version), size in memory_profile.timeline:
        if not isinstance(tensor_key, TensorKey):
            continue
        if action == Action.CREATE:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
        elif action == Action.DESTROY:
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
        elif action == Action.INCREMENT_VERSION:
            # a version bump is modeled as free-then-realloc of the storage
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
            kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
        elif action == Action.PREEXISTING:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)


    # create the final snapshot state
    blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
                     for (tensor_key, version), event in kv_to_elem.items()]
    for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)):
        seg = snapshot['segments'][device]  # type: ignore[index]
        last_addr = seg['address']
        for _, addr, size, frames in blocks:
            # fill the address gap before this block with an 'inactive' block
            if last_addr < addr:
                seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
            seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
            last_addr = addr + size
        if last_addr < seg['total_size']:
            seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})

    # drop devices with no blocks, then convert max-addr into an actual size
    snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]  # type: ignore[attr-defined]
    for seg in snapshot['segments']:  # type: ignore[attr-defined, name-defined, no-redef]
        seg['total_size'] -= seg['address']
        if not seg['blocks']:
            seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})

    return snapshot
519
def profile_plot(profile, device=None):
    """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.

    Args:
        profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.

    Returns:
        str: HTML of visualization
    """
    return _format_viz(_profile_to_snapshot(profile), 'Active Memory Timeline', device)
532
533
def segment_plot(data: Any, device=None):
    """Generate an html visualization of allocator segment state over a trace.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
        device: deprecated; _format_viz only warns when it is not None.

    Returns:
        str: HTML of visualization
    """
    return _format_viz(data, 'Allocator State History', device)
536
if __name__ == "__main__":
    import os.path
    thedir = os.path.realpath(os.path.dirname(__file__))
    if thedir in sys.path:
        # otherwise we find cuda/random.py as random...
        sys.path.remove(thedir)
    import argparse

    fn_name = 'torch.cuda.memory._snapshot()'
    pickled = f'pickled memory statistics from {fn_name}'
    parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')

    subparsers = parser.add_subparsers(dest='action')

    def _output(p):
        # shared '-o' option for the flamegraph-producing subcommands
        p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')

    description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.'
    stats_a = subparsers.add_parser('stats', description=description)
    stats_a.add_argument('input', help=pickled)

    description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
    trace_a = subparsers.add_parser('trace', description=description)
    trace_a.add_argument('input', help=pickled)

    description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
    segments_a = subparsers.add_parser('segments', description=description)
    segments_a.add_argument('input', help=pickled)
    _output(segments_a)

    description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
    memory_a = subparsers.add_parser('memory', description=description)
    memory_a.add_argument('input', help=pickled)
    _output(memory_a)

    description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
        'or removed between two different memorys snapshots.'
    compare_a = subparsers.add_parser('compare', description=description)
    compare_a.add_argument('before', help=pickled)
    compare_a.add_argument('after', help=pickled)
    _output(compare_a)

    plots = (
        ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."),
        ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.")
    )
    for cmd, description in plots:
        trace_plot_a = subparsers.add_parser(cmd, description=description)
        trace_plot_a.add_argument('input', help=pickled)
        help = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
        trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help)
        help = 'path to save the visualization(default: output.html)'
        trace_plot_a.add_argument('-o', '--output', default='output.html', help=help)
        if cmd == "trace_plot":
            help = 'visualize change to segments rather than individual allocations'
            trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help)


    args = parser.parse_args()

    def _read(name):
        """Load a snapshot pickle from a file path, or from stdin when name is '-'."""
        if name == '-':
            data = pickle.load(sys.stdin.buffer)
        else:
            # use a context manager so the file handle is closed
            # (the original opened the file and never closed it)
            with open(name, 'rb') as f:
                data = pickle.load(f)
        if isinstance(data, list):  # segments only...
            data = {'segments': data, 'traces': []}
        return data

    def _write(name, data):
        with open(name, 'w') as f:
            f.write(data)

    if args.action == 'segments':
        data = _read(args.input)
        _write(args.output, segments(data))
    elif args.action == 'memory':
        data = _read(args.input)
        _write(args.output, memory(data))
    elif args.action == 'stats':
        data = _read(args.input)
        print(segsum(data))
    elif args.action == 'trace':
        data = _read(args.input)
        print(trace(data))
    elif args.action == 'compare':
        before = _read(args.before)
        after = _read(args.after)
        _write(args.output, compare(before, after))
    elif args.action == 'trace_plot':
        data = _read(args.input)
        _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))
    elif args.action == 'segment_plot':
        data = _read(args.input)
        _write(args.output, segment_plot(data, device=args.device))
633