# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import typing
from dataclasses import dataclass, field
from typing import List

import executorch.exir.memory as memory
import torch
from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import get_node_tensor_specs
from executorch.exir.tensor import num_bytes_from_shape_and_dtype
from torch.export import ExportedProgram


@dataclass
class Allocation:
    """
    A single tensor allocation that is live at some step of the timeline.
    """

    name: str
    op_name: str
    memory_id: int
    memory_offset: int
    size_bytes: int
    fqn: str
    file_and_line_num: str


@dataclass
class MemoryTimeline:
    """
    The list of allocations that are live at one step of the timeline.
    """

    allocations: List[Allocation] = field(default_factory=list)


def _get_module_hierarchy(node: torch.fx.Node) -> str:
    """
    Get the module hierarchy (fully qualified name) of the given node.
    """
    module_stack = node.meta.get("nn_module_stack")
    if module_stack is not None:
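        # `nn_module_stack` maps stack keys to (fqn, module type) pairs,
        # ordered from the outermost to the innermost module. An illustrative,
        # made-up value (the exact shape can vary across torch versions):
        #   {"L__self__": ("", MyModel), "L__self___conv": ("conv", Conv2d)}
        # so the first element of the last value is the innermost module's
        # fully qualified name.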
        module_values_list = list(module_stack.values())
        return module_values_list[-1][0]
    return ""


def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]:
    """
    Creates a memory timeline, where each step in the timeline is the list of
    active allocations at that timestep.
    """
    nodes = graph.nodes
    memory_timeline = [None] * len(nodes)
    for node in nodes:
        if node.op == "output":
            continue
        if node.target == memory.alloc:
            continue
        tensor_specs = get_node_tensor_specs(node)
        if tensor_specs is None:
            continue
        for tensor_spec in tensor_specs:
            # TODO: Make use of mem_id in the allocation info
            if tensor_spec is None or tensor_spec.mem_id is None or tensor_spec.const:
                continue
            # Lifetime is the [first, last] pair of node indices at which
            # this tensor is live.
            start, end = tensor_spec.lifetime
            size = num_bytes_from_shape_and_dtype(
                typing.cast(torch.Size, tensor_spec.shape), tensor_spec.dtype
            )
            stack_trace = node.meta.get("stack_trace")
            fqn = _get_module_hierarchy(node)
            for j in range(start, end + 1):
                if memory_timeline[j] is None:
                    # pyre-ignore
                    memory_timeline[j] = MemoryTimeline()
                # pyre-ignore
                memory_timeline[j].allocations.append(
                    Allocation(
                        node.name,
                        node.target,
                        tensor_spec.mem_id,
                        tensor_spec.mem_offset,
                        size,
                        fqn,
                        stack_trace,
                    )
                )
    # pyre-ignore
    return memory_timeline
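

# Illustrative helper, added for exposition and not part of the original
# module: summing the live allocation sizes at each step of the timeline
# produced by create_tensor_allocation_info gives the peak activation memory
# of the planned program.
def _peak_activation_bytes(memory_timeline: List[MemoryTimeline]) -> int:
    peak = 0
    for timeline_step in memory_timeline:
        # Steps with no live, non-constant tensors are left as None.
        if timeline_step is None:
            continue
        live_bytes = sum(a.size_bytes for a in timeline_step.allocations)
        peak = max(peak, live_bytes)
    return peak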


def _validate_memory_planning_is_done(exported_program: ExportedProgram) -> bool:
    """
    Validate whether memory planning has been done on the given program.
    """
    for node in exported_program.graph.nodes:
        # If there is at least one memory allocation node, then we know memory planning has been done.
        if node.target == memory.alloc:
            return True
    return False


def generate_memory_trace(
    executorch_program_manager: ExecutorchProgramManager,
    chrome_trace_filename: str,
    enable_memory_offsets: bool = False,
    method_name: str = "forward",
):
    """
    Generate the memory timeline from the given ExecuTorch program and write
    it out as a Chrome trace in JSON format.

    Args:
        executorch_program_manager: The ExecuTorch program to be analyzed.
        chrome_trace_filename: Path of the JSON file the trace is written to.
        enable_memory_offsets: If True, position each allocation on the x axis
            at its planned memory offset instead of stacking allocations back
            to back.
        method_name: Name of the method in the program to analyze.

    Trace format:
        Each thread represents a unit of time, so to navigate the timeline,
        scroll up and down. For each thread, the x axis represents the live
        tensor objects, sized according to their allocation size.
    """
    if not isinstance(executorch_program_manager, ExecutorchProgramManager):
        raise ValueError(
            f"generate_memory_trace expects ExecutorchProgramManager instance but got {type(executorch_program_manager)}"
        )

    exported_program = executorch_program_manager.exported_program(method_name)
    if not _validate_memory_planning_is_done(exported_program):
        raise ValueError("Executorch program does not have memory planning.")

    memory_timeline = create_tensor_allocation_info(exported_program.graph)
    root = {}
    trace_events = []
    root["traceEvents"] = trace_events

    tid = 0
    for memory_timeline_event in memory_timeline:
        start_time = 0
        if memory_timeline_event is None:
            continue
        for allocation in memory_timeline_event.allocations:
            e = {}
            e["name"] = allocation.name
            e["cat"] = "memory_allocation"
            e["ph"] = "X"
            e["ts"] = (
                int(allocation.memory_offset)
                if enable_memory_offsets
                else int(start_time)
            )
            allocation_size_bytes = allocation.size_bytes
            e["dur"] = int(allocation_size_bytes)
            e["pid"] = int(allocation.memory_id)
            e["tid"] = tid
            e["args"] = {}
            e["args"]["op_name"] = f"{allocation.op_name}"
            # Memory ID refers to the memory space, typically from 1 to N.
            # For CPU, everything is allocated in one "space"; other backends
            # may have multiple.
            e["args"]["Memory ID"] = allocation.memory_id
            e["args"]["fqn"] = f"{allocation.fqn}"
            e["args"]["source"] = f"{allocation.file_and_line_num}"
            e["args"]["bytes"] = allocation.size_bytes
            start_time += allocation_size_bytes
            trace_events.append(e)
        tid += 1

    json_content: str = json.dumps(root, indent=2)

    with open(chrome_trace_filename, "wb") as json_file:
        json_file.write(json_content.encode("ascii"))
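

# Example usage, as a minimal sketch: `MyModel` and its example inputs are
# placeholders. Lowering with the standard export -> to_edge -> to_executorch
# flow runs memory planning, after which the trace can be dumped and opened
# in chrome://tracing or Perfetto.
#
#     import torch
#     from executorch.exir import to_edge
#
#     model = MyModel().eval()
#     example_inputs = (torch.randn(1, 3, 224, 224),)
#     program = to_edge(torch.export.export(model, example_inputs)).to_executorch()
#     generate_memory_trace(program, "activation_memory_trace.json")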
170