# mypy: allow-untyped-defs
import operator
import pickle
from collections import defaultdict
from itertools import chain
from typing import Any, Callable, Dict, List, no_type_check, Sequence, Tuple, TYPE_CHECKING

import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle


BYTES_PER_MB = 1024 * 1024.0


class MemoryProfileDispatchMode(TorchDispatchMode):
21    """Run in ``TorchDispatchMode`` to get memory stats at operator level."""

    def __init__(self, memory_tracker) -> None:
        self.memory_tracker = memory_tracker

    def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        # ``kwargs`` may be passed as ``None``; normalize it so the call below
        # does not fail.
        kwargs = kwargs if kwargs is not None else {}
        rs = func(*args, **kwargs)
        # ``aten.detach`` calls are not recorded.
        if func == torch.ops.aten.detach.default:
            return rs
        # Operator names are recorded as "<module>.<forward|backward>.<op>_<count>",
        # e.g. "layer1.forward.mm_0" (illustrative).
        func_name: str = (
            self.memory_tracker._cur_module_name
            + "."
            + func.__name__
            + "_"
            + str(self.memory_tracker._operator_names[func.__name__])
        )
        self.memory_tracker._operator_names[func.__name__] = (
            self.memory_tracker._operator_names[func.__name__] + 1
        )
        self.memory_tracker._record_memory_stats(func_name)

        return rs


class MemoryTracker:
46    """
47    Collect and plot the memory stats at operator level.
48
49    Includes ``memories_allocated``, ``memories_active`` and ``memories_reserved``.
50    It also prints a summary for the top 20 operators that generate the most memories.
51
52    Example usage:
53
54        >>> # xdoctest: +SKIP(failing)
55        >>> net.cuda()
56        >>> input = input.cuda()
57
58        >>> mem_tracker = MemoryTracker()
59        >>> mem_tracker.start_monitor(net)
60
61        >>> net.zero_grad(True)
62        >>> loss = net(input)
63        >>> if isinstance(loss, dict):
64        >>>    loss = loss['out']
65        >>> loss.sum().backward()
66        >>> net.zero_grad(set_to_none=True)
67
68        >>> mem_tracker.stop()
69        >>> mem_tracker.summary()
70        >>> mem_tracker.show_traces()
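        >>> # Optionally pickle the stats and plot them later (illustrative path):
        >>> mem_tracker.save_stats("memory_stats.pickle")
        >>> mem_tracker.show_traces("memory_stats.pickle")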
71    """
72
73    def __init__(self) -> None:
74        torch._C._log_api_usage_once("torch.distributed.memory_tracker")
75        self._hooks: List[RemovableHandle] = []
76        self._operator_names: Dict[str, int] = defaultdict(int)
77        self.memories_allocated: Dict[int, Dict[str, float]] = defaultdict()
78        self.memories_active: Dict[int, Dict[str, float]] = defaultdict()
79        self.memories_reserved: Dict[int, Dict[str, float]] = defaultdict()
80        self._markers: Dict[str, int] = defaultdict(int)
81        self._cur_module_name: str = ""
82        self._op_index: int = 0
83        self._num_cuda_retries: int = 0
84
85    @no_type_check
86    def start_monitor(self, root_module: nn.Module) -> None:
87        """
88        Register module hooks and entering ``MemoryProfileDispatchMode``.
89
90        This enables operator level memory stats can be tracked during module runtime.
91        """
92        self._clear_state()
93        root_module.__setattr__("_memory_tracker_is_root", True)
94        for name, m in root_module.named_modules():
95            if m is not root_module:
96                m.__setattr__("_memory_tracker_is_root", False)
97            # fused_proxy_group does not support hooks
98            if ".fused_proxy_grouped_embedding_bag" in name:
99                continue
100            # hook ordering with other hooks added by users is not managed, so
101            # the memory stats tracked here may not completely accurate.
102            h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name))
103            h2 = m.register_forward_hook(self._create_post_forward_hook(name))
104            # it does not work well with jagged tensor somehow, the root cause is not
105            # clear and remove it for now as it does not really capture important info.
106            # h3 = m.register_backward_hook(self._create_backward_hook(name))
107            self._hooks.extend([h1, h2])
108        torch.cuda.empty_cache()
109        assert getattr(self, "profile_mode", None) is None
110        self.profile_mode = MemoryProfileDispatchMode(self)
111        self.profile_mode.__enter__()

    @no_type_check
    def stop(self) -> None:
        """
        Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at the operator level.

        Also collect some aggregated stats from the tracked run, such as the CUDA ``num_alloc_retries``.
        """
        self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)

        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        assert getattr(self, "profile_mode", None) is not None
        self.profile_mode.__exit__(None, None, None)
        self.profile_mode = None

    @no_type_check
    def summary(self, top: int = 20) -> None:
        """
        Print out the top operators that generate the most memory.

        The number of top operators reported can be configured via ``top``.
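
        Example (illustrative; assumes stats have already been collected):

        >>> # xdoctest: +SKIP
        >>> mem_tracker.summary(top=10)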
135        """
136        op_diff: Dict[str, float] = defaultdict(float)
137        op_name, previous_allocated_memory = self.memories_allocated[0]
138        for i in range(1, self._op_index):
139            op_name, current_allocated_memory = self.memories_allocated[i]
140            op_diff[op_name] = current_allocated_memory - previous_allocated_memory
141            previous_allocated_memory = current_allocated_memory
142
143        print("------------------------------------------------")
144        print(f"The number of cuda retries are: {self._num_cuda_retries}")
145        print(f"Top {top} ops that generates memory are:")
146        for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[
147            :top
148        ]:
149            print(f"{k}: {v}MB")
150        print("------------------------------------------------")

    @no_type_check
    def show_traces(self, path: str = "") -> None:
        """Plot the traces of the tracked memory stats, optionally loading previously saved stats from ``path``."""
        import matplotlib.pyplot as plt

        def _plot_figure(x, y_values, labels):
            min_val = min(list(chain(*y_values))) * 0.999
            max_val = max(list(chain(*y_values))) * 1.001
            plt.figure()
            for y, label in zip(y_values, labels):
                plt.plot(x, y, label=label)
            plt.xlabel("# Operator Calls")
            plt.ylabel("Memory (MB)")
            plt.legend()
            # Draw vertical marker lines, e.g. at the forward/backward boundary.
            for marker_name, marker in self._markers.items():
                if marker_name == "fw_bw_boundary":
                    plt.plot(
                        [marker, marker],
                        [min_val, max_val],
                        "r",
                        lw=2,
                        label=marker_name,
                    )
                else:
                    plt.plot(
                        [marker, marker],
                        [min_val, max_val],
                        "k-",
                        lw=2,
                        label=marker_name,
                    )

        if path != "":
            self.load(path)

        y_1 = [mb for (name, mb) in self.memories_allocated.values()]
        y_2 = [mb for (name, mb) in self.memories_active.values()]
        y_3 = [mb for (name, mb) in self.memories_reserved.values()]
        x = list(range(len(y_1)))
        # Split into separate figures as well, since there can be a big difference
        # between "reserved_memory" and "allocated_memory" or "active_memory".
        _plot_figure(
            x,
            [list(y_1), list(y_2), list(y_3)],
            ["allocated_memory", "active_memory", "reserved_memory"],
        )
        _plot_figure(x, [list(y_1)], ["allocated_memory"])
        _plot_figure(x, [list(y_2)], ["active_memory"])
        _plot_figure(x, [list(y_3)], ["reserved_memory"])

    def save_stats(self, path: str) -> None:
        """Save the stats with pickle during runtime so that users can plot the traces elsewhere, such as in a notebook."""
        stats = {
            "memories_allocated": self.memories_allocated,
            "memories_active": self.memories_active,
            "memories_reserved": self.memories_reserved,
            "markers": self._markers,
            "num_alloc_retries": self._num_cuda_retries,
        }

        with open(path, "wb") as f:
            pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)

    def load(self, path: str) -> None:
        """Load the pickled memory stats to plot the traces or print the summary."""
        with open(path, "rb") as f:
            stats = pickle.load(f)

        self.memories_allocated = stats["memories_allocated"]
        self.memories_active = stats["memories_active"]
        self.memories_reserved = stats["memories_reserved"]
        self._markers = stats["markers"]
        self._num_cuda_retries = stats["num_alloc_retries"]
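
    # Offline usage sketch (illustrative, not from the original file): stats that
    # were pickled with ``save_stats`` during a run can be re-loaded later in a
    # separate process or notebook and inspected without re-running the model:
    #
    #     viewer = MemoryTracker()
    #     viewer.load("memory_stats.pickle")  # hypothetical path
    #     viewer.summary()
    #     viewer.show_traces()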

    def _create_pre_forward_hook(self, name: str) -> Callable:
        """Prefix the operator name with the current module name and 'forward', and insert the 'fw_start' marker at the start of the forward pass."""

        def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
            self._cur_module_name = f"{name}.forward"
            if (
                hasattr(module, "_memory_tracker_is_root")
                and module._memory_tracker_is_root
            ):
                self._add_marker("fw_start")

        return _pre_forward_hook

    def _create_post_forward_hook(self, name: str) -> Callable:
        """Insert the 'fw_bw_boundary' marker at the boundary between the forward and backward passes."""

        def _post_forward_hook(
            module: nn.Module,
            inputs: Sequence[torch.Tensor],
            outputs: Sequence[torch.Tensor],
        ) -> None:
            if (
                hasattr(module, "_memory_tracker_is_root")
                and module._memory_tracker_is_root
            ):
                self._add_marker("fw_bw_boundary")

        return _post_forward_hook

    def _create_backward_hook(self, name: str) -> Callable:
        """Prefix the operator name with the current module name and 'backward'."""

        def _backward_hook(
            module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor
        ) -> None:
            self._cur_module_name = f"{name}.backward"

        return _backward_hook

    @no_type_check
    def _record_memory_stats(self, fn_name: str) -> None:
        """
        Record current memory allocated, current memory active and current memory reserved.

        The memory stats dict is indexed with ``self._op_index``.
        """
        memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB
        memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB
        memory_active: float = (
            torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB
        )
        self.memories_allocated[self._op_index] = (fn_name, memory_allocated)
        self.memories_reserved[self._op_index] = (fn_name, memory_reserved)
        self.memories_active[self._op_index] = (fn_name, memory_active)
        self._op_index += 1

    def _add_marker(self, marker_name: str) -> None:
        """Set the marker's x-axis value to the current number of recorded operator calls."""
        marker_val = len(self.memories_allocated.values())
        self._markers[marker_name] = marker_val

    def _clear_state(self) -> None:
        """Clear the internal state when ``start_monitor()`` is called."""
        self._operator_names.clear()
        self.memories_allocated.clear()
        self.memories_active.clear()
        self.memories_reserved.clear()
        self._markers.clear()
        self._cur_module_name = ""
        self._op_index = 0
        self._num_cuda_retries = 0
296