# mypy: allow-untyped-defs
import operator
import pickle
from collections import defaultdict
from itertools import chain
from typing import (
    Any,
    Callable,
    Dict,
    List,
    no_type_check,
    Sequence,
    Tuple,
    TYPE_CHECKING,
)

import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle


BYTES_PER_MB = 1024 * 1024.0


class MemoryProfileDispatchMode(TorchDispatchMode):
    """Run in ``TorchDispatchMode`` to get memory stats at operator level."""

    def __init__(self, memory_tracker) -> None:
        self.memory_tracker = memory_tracker

    def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        rs = func(*args, **kwargs)
        if func == torch.ops.aten.detach.default:
            return rs
        func_name: str = (
            self.memory_tracker._cur_module_name
            + "."
            + func.__name__
            + "_"
            + str(self.memory_tracker._operator_names[func.__name__])
        )
        self.memory_tracker._operator_names[func.__name__] = (
            self.memory_tracker._operator_names[func.__name__] + 1
        )
        self.memory_tracker._record_memory_stats(func_name)

        return rs


class MemoryTracker:
    """
    Collect and plot memory stats at the operator level.

    Includes ``memories_allocated``, ``memories_active`` and ``memories_reserved``.
    It also prints a summary of the top 20 operators that generate the most memory.

    Example usage:

        >>> # xdoctest: +SKIP(failing)
        >>> net.cuda()
        >>> input = input.cuda()

        >>> mem_tracker = MemoryTracker()
        >>> mem_tracker.start_monitor(net)

        >>> net.zero_grad(True)
        >>> loss = net(input)
        >>> if isinstance(loss, dict):
        >>>     loss = loss['out']
        >>> loss.sum().backward()
        >>> net.zero_grad(set_to_none=True)

        >>> mem_tracker.stop()
        >>> mem_tracker.summary()
        >>> mem_tracker.show_traces()
    """

    def __init__(self) -> None:
        torch._C._log_api_usage_once("torch.distributed.memory_tracker")
        self._hooks: List[RemovableHandle] = []
        self._operator_names: Dict[str, int] = defaultdict(int)
        self.memories_allocated: Dict[int, Tuple[str, float]] = defaultdict()
        self.memories_active: Dict[int, Tuple[str, float]] = defaultdict()
        self.memories_reserved: Dict[int, Tuple[str, float]] = defaultdict()
        self._markers: Dict[str, int] = defaultdict(int)
        self._cur_module_name: str = ""
        self._op_index: int = 0
        self._num_cuda_retries: int = 0

    @no_type_check
    def start_monitor(self, root_module: nn.Module) -> None:
        """
        Register module hooks and enter ``MemoryProfileDispatchMode``.

        This enables operator-level memory stats to be tracked during module runtime.
        """
        self._clear_state()
        root_module.__setattr__("_memory_tracker_is_root", True)
        for name, m in root_module.named_modules():
            if m is not root_module:
                m.__setattr__("_memory_tracker_is_root", False)
            # fused_proxy_grouped_embedding_bag does not support hooks
            if ".fused_proxy_grouped_embedding_bag" in name:
                continue
            # Ordering with other hooks added by users is not managed, so the
            # memory stats tracked here may not be completely accurate.
            h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name))
            h2 = m.register_forward_hook(self._create_post_forward_hook(name))
            # Backward hooks do not work well with jagged tensors and the root cause
            # is not clear, so the hook is disabled for now as it does not capture
            # important info anyway.
            # h3 = m.register_backward_hook(self._create_backward_hook(name))
            self._hooks.extend([h1, h2])
        torch.cuda.empty_cache()
        assert getattr(self, "profile_mode", None) is None
        self.profile_mode = MemoryProfileDispatchMode(self)
        self.profile_mode.__enter__()

    @no_type_check
    def stop(self) -> None:
        """
        Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level.

        Also collect some aggregated stats gathered while tracking was enabled, such as the CUDA ``num_alloc_retries``.
        """
        self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)

        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        assert getattr(self, "profile_mode", None) is not None
        self.profile_mode.__exit__(None, None, None)
        self.profile_mode = None

    @no_type_check
    def summary(self, top: int = 20) -> None:
        """
        Print out the top operators that generate the most memory.

        The number of top operators reported can be configured via ``top``.
        """
        op_diff: Dict[str, float] = defaultdict(float)
        op_name, previous_allocated_memory = self.memories_allocated[0]
        for i in range(1, self._op_index):
            op_name, current_allocated_memory = self.memories_allocated[i]
            op_diff[op_name] = current_allocated_memory - previous_allocated_memory
            previous_allocated_memory = current_allocated_memory

        print("------------------------------------------------")
        print(f"The number of CUDA retries is: {self._num_cuda_retries}")
        print(f"Top {top} ops that generate memory are:")
        for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[
            :top
        ]:
            print(f"{k}: {v}MB")
        print("------------------------------------------------")

    @no_type_check
    def show_traces(self, path: str = "") -> None:
        """Plot the allocated, active and reserved memory traces, optionally loading saved stats from ``path``."""
        import matplotlib.pyplot as plt

        def _plot_figure(x, y_values, labels):
            min_val = min(list(chain(*y_values))) * 0.999
            max_val = max(list(chain(*y_values))) * 1.001
            plt.figure()
            for y, label in zip(y_values, labels):
                plt.plot(x, y, label=label)
            plt.xlabel("# Operator Calls")
            plt.ylabel("Memory (MB)")
            plt.legend()
            for marker_name, marker in self._markers.items():
                if marker_name == "fw_bw_boundary":
                    plt.plot(
                        [marker, marker],
                        [min_val, max_val],
                        "r",
                        lw=2,
                        label=marker_name,
                    )
                else:
                    plt.plot(
                        [marker, marker],
                        [min_val, max_val],
                        "k-",
                        lw=2,
                        label=marker_name,
                    )

        if path != "":
            self.load(path)

        y_1 = [mb for (name, mb) in self.memories_allocated.values()]
        y_2 = [mb for (name, mb) in self.memories_active.values()]
        y_3 = [mb for (name, mb) in self.memories_reserved.values()]
        x = list(range(len(y_1)))
        # Split the traces into separate figures in case there is a big difference
        # between "reserved_memory" and "allocated_memory" or "active_memory".
        _plot_figure(
            x,
            [list(y_1), list(y_2), list(y_3)],
            ["allocated_memory", "active_memory", "reserved_memory"],
        )
        _plot_figure(x, [list(y_1)], ["allocated_memory"])
        _plot_figure(x, [list(y_2)], ["active_memory"])
        _plot_figure(x, [list(y_3)], ["reserved_memory"])

    def save_stats(self, path: str) -> None:
        """Save the stats using pickle during runtime so that users can plot the traces elsewhere, e.g. in a notebook."""
        stats = {
            "memories_allocated": self.memories_allocated,
            "memories_active": self.memories_active,
            "memories_reserved": self.memories_reserved,
            "markers": self._markers,
            "num_alloc_retries": self._num_cuda_retries,
        }

        with open(path, "wb") as f:
            pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)

    def load(self, path: str) -> None:
        """Load the pickled memory stats to plot the traces or print the summary."""
        with open(path, "rb") as f:
            stats = pickle.load(f)

        self.memories_allocated = stats["memories_allocated"]
        self.memories_active = stats["memories_active"]
        self.memories_reserved = stats["memories_reserved"]
        self._markers = stats["markers"]
        self._num_cuda_retries = stats["num_alloc_retries"]

    def _create_pre_forward_hook(self, name: str) -> Callable:
        """Prefix operator names with the current module name and 'forward', and insert the 'fw_start' marker at the start of the forward pass."""

        def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
            self._cur_module_name = f"{name}.forward"
            if (
                hasattr(module, "_memory_tracker_is_root")
                and module._memory_tracker_is_root
            ):
                self._add_marker("fw_start")

        return _pre_forward_hook

    def _create_post_forward_hook(self, name: str) -> Callable:
        """Insert the 'fw_bw_boundary' marker at the boundary between the forward and backward passes."""

        def _post_forward_hook(
            module: nn.Module,
            inputs: Sequence[torch.Tensor],
            outputs: Sequence[torch.Tensor],
        ) -> None:
            if (
                hasattr(module, "_memory_tracker_is_root")
                and module._memory_tracker_is_root
            ):
                self._add_marker("fw_bw_boundary")

        return _post_forward_hook

    def _create_backward_hook(self, name: str) -> Callable:
        """Prefix operator names with the current module name and 'backward'."""

        def _backward_hook(
            module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor
        ) -> None:
            self._cur_module_name = f"{name}.backward"

        return _backward_hook

    @no_type_check
    def _record_memory_stats(self, fn_name: str) -> None:
        """
        Record current memory allocated, current memory active and current memory reserved.

        The memory stats dict is indexed with ``self._op_index``.
270 """ 271 memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB 272 memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB 273 memory_active: float = ( 274 torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB 275 ) 276 self.memories_allocated[self._op_index] = (fn_name, memory_allocated) 277 self.memories_reserved[self._op_index] = (fn_name, memory_reserved) 278 self.memories_active[self._op_index] = (fn_name, memory_active) 279 self._op_index += 1 280 281 def _add_marker(self, marker_name: str) -> None: 282 """Set the marker's x-axis value.""" 283 marker_val = len(self.memories_allocated.values()) 284 self._markers[marker_name] = marker_val 285 286 def _clear_state(self) -> None: 287 """Clear states when start_monitor() is called.""" 288 self._operator_names.clear() 289 self.memories_allocated.clear() 290 self.memories_active.clear() 291 self.memories_reserved.clear() 292 self._markers.clear() 293 self._cur_module_name = "" 294 self._op_index = 0 295 self._num_cuda_retries = 0 296