import math
import os
import re
import warnings
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import Self

import torch
from torch import nn, optim
from torch.distributed._tools.mod_tracker import ModTracker
from torch.optim.optimizer import (
    register_optimizer_step_post_hook,
    register_optimizer_step_pre_hook,
)
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    TorchDispatchMode,
)
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

# This value is hard-coded here:
# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
_PYTORCH_MIN_ALLOCATE = (
    2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1
)
_TOTAL_KEY = "Total"

__all__ = ["MemTracker"]


class _RefType(str, Enum):
    """Base class for defining memory reference types, categorizing tensors based on their usage within a model."""


class _State(str, Enum):
    """Base class for defining module states used to capture memory snapshots."""


class _MemRefType(_RefType):
    """
    An enum to define memory reference types, categorizing tensors based on their usage within a model.

    - PARAM: Tensors registered as nn.Parameter within modules.
    - BUFFER: Tensors registered as nn.Buffer within modules.
    - GRAD: Gradients associated with parameters.
    - ACT: Tensors produced during the forward pass and recomputation in activation checkpointing.
    - TEMP: Temporary memory used during the backward pass, including gradients of activations.
    - OPT: Tensors holding optimizer states.
    - OTH: Tensors registered via ``track_external`` that do not fit the above categories.
    """

    PARAM = "Parameter"
    BUFFER = "Buffer"
    GRAD = "Gradient"
    ACT = "Activation"
    TEMP = "Temp"
    OPT = "Optstate"
    OTH = "Other"


class _ModState(_State):
    """
    An enum to define the state of a module.

    - PRE_FW: The module is about to run the forward pass.
    - POST_FW: The module has finished running the forward pass.
    - PEAK_FW: The module has reached the peak memory usage during the forward pass.
    - PRE_BW: The module is about to run the backward pass.
    - PRE_FW_AC: The module is about to run the forward pass with activation checkpointing.
    - POST_FW_AC: The module has finished running the forward pass with activation checkpointing.
    - POST_BW: The module has finished running the backward pass.
    - PEAK_BW: The module has reached the peak memory usage during the backward pass.
    """

    PRE_FW = "Pre-Forward"
    POST_FW = "Post-Forward"
    PEAK_FW = "Peak-Forward"
    PRE_BW = "Pre-Backward"
    PRE_FW_AC = "Pre-Forward-AC"
    POST_FW_AC = "Post-Forward-AC"
    POST_BW = "Post-Backward"
    PEAK_BW = "Peak-Backward"


class _ModMemStats:
    """
    A class to store the memory statistics of a module.

    Args:
        mod_fqn (str): The fully qualified name of the module.
    Attributes:
        mod_fqn (str): The fully qualified name of the module.
        parameter_mem (int): The memory usage of the parameters of the module.
        buffer_mem (int): The memory usage of the buffers of the module.
        input_mem (int): The memory usage of the inputs to the module.
        output_mem (int): The memory usage of the outputs from the module.
        snapshots (Dict[_ModState, List[Dict[torch.device, Dict[str, int]]]]): A dictionary of lists of memory
            snapshots of the module at the different states defined by ``_ModState``.
    Note:
        Each memory snapshot is stored as a dictionary - Dict[torch.device, Dict[str, int]], where each key is a device,
        and each value is another dictionary with keys as memory reference types defined by ``_MemRefType`` and
        values as the memory consumed in bytes.
    """

    def __init__(self, mod_fqn: str):
        self.mod_fqn = mod_fqn
        self.parameter_mem: int
        self.buffer_mem: int
        self.input_mem: int
        self.output_mem: int
        self.local_peak: Dict[torch.device, int] = {}
        self.snapshots: Dict[_ModState, List[Dict[torch.device, Dict[str, int]]]] = {}
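
# For illustration (editor's sketch, not part of the tracker's API): a single
# snapshot stored in ``_ModMemStats.snapshots`` is a plain nested dict. The keys
# of the inner dict are ``_MemRefType`` members (whose values are the strings
# shown), and the byte counts below are made up:
#
#     {
#         torch.device("cuda", 0): {
#             "Parameter": 4096,
#             "Buffer": 0,
#             "Gradient": 4096,
#             "Activation": 8192,
#             "Temp": 0,
#             "Optstate": 8192,
#             "Other": 0,
#             "Total": 24576,
#         }
#     }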
class _WeakRefInfo:
    """
    Manages memory statistics and device attributes for tensor storages.
    """

    def __init__(
        self, size: int, element_size: int, device: torch.device, reftype: _RefType
    ) -> None:
        """
        Initializes the ``_WeakRefInfo`` object with tensor storage properties.

        Args:
            size (int): The number of elements in the tensor storage.
            element_size (int): The size of each element in the tensor storage.
            device (torch.device): The device on which the tensor is allocated.
            reftype (_RefType): The reference type of the tensor.
        """
        self.size = size
        self.element_size = element_size
        self.reftype = reftype
        self.device = device
        self.mem_consumed = self._calculate_mem_consumed()

    def _calculate_mem_consumed(self) -> int:
        """
        Calculates the memory consumed by the tensor storage, considering device-specific allocation rules.

        Returns:
            int: The memory consumed in bytes.
        """
        mem = self.size * self.element_size
        if self.device.type == "cuda":
            return math.ceil(mem / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE
        return mem

    def update_mem_consumed(self, st: torch.UntypedStorage) -> int:
        """
        Updates and returns the memory consumed if the storage size has changed.

        Args:
            st (torch.UntypedStorage): The tensor storage to check for size updates.

        Returns:
            int: The updated memory consumed in bytes.
        """
        if st.size() != self.size:
            self.size = st.size()
            self.mem_consumed = self._calculate_mem_consumed()
        return self.mem_consumed

    @staticmethod
    def get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]:
        """
        Recursively extracts untyped storages from a tensor or its subclasses.

        Args:
            t (torch.Tensor): The tensor to extract storages from.

        Returns:
            Set[torch.UntypedStorage]: A set of untyped storages.
        """
        unflattened_tensors = [t]
        flattened_tensor_storages = set()
        while len(unflattened_tensors) > 0:
            obj = unflattened_tensors.pop()
            if is_traceable_wrapper_subclass(obj):
                attrs, _ = obj.__tensor_flatten__()  # type: ignore[attr-defined]
                unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
            else:
                if not hasattr(obj, "untyped_storage"):
                    warnings.warn(
                        f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
                        category=UserWarning,
                        stacklevel=2,
                    )
                else:
                    flattened_tensor_storages.add(obj.untyped_storage())
        return flattened_tensor_storages

    @classmethod
    def create_winfo(
        cls,
        st: torch.UntypedStorage,
        device: torch.device,
        reftype: _RefType,
        callback: Optional[Callable[[Self, weakref.ref], Any]] = None,
    ) -> Tuple[Self, weakref.ref]:
        """
        Creates a new ``_WeakRefInfo`` instance and a weak reference to a ``torch.UntypedStorage`` object,
        optionally attaching a callback to the weak reference.

        Args:
            st (torch.UntypedStorage): The storage object for which to create the weak reference info.
            device (torch.device): The device associated with the storage object.
            reftype (_RefType): The type of reference, used to categorize the storage.
            callback (Optional[Callable[[Self, weakref.ref], Any]]): A callback function that is called when
                the storage object is about to be finalized (garbage collected). The callback function
                should accept two arguments: the ``_WeakRefInfo`` instance and the weak reference to the storage.
        Returns:
            Tuple[Self, weakref.ref]: A tuple containing the newly created ``_WeakRefInfo`` instance and the
                weak reference to the storage object. The weak reference may have an attached callback if provided.
        """
        winfo = cls(st.size(), st.element_size(), device, reftype)
        w_st = weakref.ref(st, partial(callback, winfo) if callback else None)
        return winfo, w_st
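
# A worked example (editor's illustration): the CUDA caching allocator hands out
# blocks in multiples of ``_PYTORCH_MIN_ALLOCATE`` (512 bytes), so a
# ``torch.float32`` tensor with 100 elements (400 bytes) is accounted as one
# 512-byte block:
#
#     >>> import math
#     >>> mem = 100 * 4  # size * element_size = 400 bytes
#     >>> math.ceil(mem / 512) * 512
#     512
#
# On CPU no such rounding is applied and the raw 400 bytes are reported.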
194 """ 195 unflattened_tensors = [t] 196 flattened_tensor_storages = set() 197 while len(unflattened_tensors) > 0: 198 obj = unflattened_tensors.pop() 199 if is_traceable_wrapper_subclass(obj): 200 attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined] 201 unflattened_tensors.extend([getattr(obj, attr) for attr in attrs]) 202 else: 203 if not hasattr(obj, "untyped_storage"): 204 warnings.warn( 205 f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}", 206 category=UserWarning, 207 stacklevel=2, 208 ) 209 else: 210 flattened_tensor_storages.add(obj.untyped_storage()) 211 return flattened_tensor_storages 212 213 @classmethod 214 def create_winfo( 215 cls, 216 st: torch.UntypedStorage, 217 device: torch.device, 218 reftype: _RefType, 219 callback: Optional[Callable[[Self, weakref.ref], Any]] = None, 220 ) -> Tuple[Self, weakref.ref]: 221 """ 222 Creates a new ``_WeakRefInfo`` instance and a weak reference to a ``torch.UntypedStorage`` object, 223 optionally attaching a callback to the weak reference. 224 225 Args: 226 st (torch.UntypedStorage): The storage object for which to create the weak reference info. 227 device (torch.device): The device associated with the storage object. 228 reftype (_RefType): The type of reference, used to categorize the storage. 229 callback (Optional[Callable[[Self, weakref.ref]]]): A callback function that is called when 230 the storage object is about to be finalized (garbage collected). The callback function 231 should accept two arguments: the ``_WeakRefInfo`` instance and the weak reference to the storage. 232 Returns: 233 Tuple[Self, weakref.ref]: A tuple containing the newly created ``_WeakRefInfo`` instance and the 234 weak reference to the storage object. The weak reference may have an attached callback if provided. 235 """ 236 237 winfo = cls(st.size(), st.element_size(), device, reftype) 238 w_st = weakref.ref(st, partial(callback, winfo) if callback else None) 239 return winfo, w_st 240 241 242def _get_mem_divisor(units: str) -> int: 243 unit_dict = {"B": 1, "KiB": 2**10, "MiB": 2**20, "GiB": 2**30} 244 if units in unit_dict: 245 return unit_dict[units] 246 else: 247 raise ValueError( 248 f"Unsupported unit: {units}. Supported units are: {', '.join(unit_dict.keys())}" 249 ) 250 251 252def _rounding_fn(value: int, divisor: int, precision: int) -> Union[float, int]: 253 return value if divisor == 1 else round(value / divisor, precision) 254 255 256def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) -> None: 257 if len(snapshot) == 0: 258 print("No memory tracked.") 259 return 260 divisor = _get_mem_divisor(units) 261 for dev, dev_snap in snapshot.items(): 262 if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: 263 continue 264 print( 265 f"Device: {dev}", 266 *( 267 f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}" 268 for k, v in dev_snap.items() 269 ), 270 sep="\n", 271 ) 272 273 274def _print_snapshot_tabular( 275 snapshot: Dict[torch.device, Dict[str, int]], units: str 276) -> None: 277 if len(snapshot) == 0: 278 print("No memory tracked.") 279 return 280 try: 281 from tabulate import tabulate 282 except ImportError as err: 283 raise ImportError( 284 "Please install tabulate to use the tabulate option." 
def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) -> None:
    if len(snapshot) == 0:
        print("No memory tracked.")
        return
    divisor = _get_mem_divisor(units)
    for dev, dev_snap in snapshot.items():
        if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
            continue
        print(
            f"Device: {dev}",
            *(
                f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}"
                for k, v in dev_snap.items()
            ),
            sep="\n",
        )


def _print_snapshot_tabular(
    snapshot: Dict[torch.device, Dict[str, int]], units: str
) -> None:
    if len(snapshot) == 0:
        print("No memory tracked.")
        return
    try:
        from tabulate import tabulate
    except ImportError as err:
        raise ImportError(
            "Please install tabulate to use the tabulate option."
        ) from err
    divisor = _get_mem_divisor(units)
    table_data = []
    key_list = list(next(iter(snapshot.values())).keys())
    headers = ["Device"] + [f"{key}" for key in key_list]

    for dev, dev_snap in snapshot.items():
        if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
            continue
        row = [str(dev)]
        row.extend(f"{_rounding_fn(v, divisor, 2)} {units}" for v in dev_snap.values())
        table_data.append(row)
    print(tabulate(table_data, headers=headers, tablefmt="rst"))


def _print_state_snapshots(
    snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str
) -> None:
    for state, snapshot_list in snapshots.items():
        print(f"{state}")
        for i, snapshot in enumerate(snapshot_list):
            print(f"# {i + 1}:")
            _print_snapshot(snapshot, units)
    print()


def _print_state_snapshots_tabular(
    snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str
) -> None:
    try:
        from tabulate import tabulate
    except ImportError as err:
        raise ImportError(
            "Please install tabulate to use the tabulate option."
        ) from err

    table_data = []
    last_state_call = None
    divisor = _get_mem_divisor(units)
    for state, snapshot_list in snapshots.items():
        for i, snapshot in enumerate(snapshot_list):
            state_call = f"{state} # {i + 1}"
            for dev, dev_snap in snapshot.items():
                if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
                    continue
                row = {
                    "State & Call": (
                        state_call if state_call != last_state_call else ""
                    ),
                    "Device": str(dev),
                }
                last_state_call = state_call
                for k, v in dev_snap.items():
                    row[f"{k}"] = f"{_rounding_fn(v, divisor, 2)} {units}"
                table_data.append(row)
    print(tabulate(table_data, headers="keys", tablefmt="rst"))


class _UpdateType(Enum):
    # These are used for tracking updates to the continuously maintained memory snapshot.
    # ADD - When a new tensor storage is tracked.
    # DEL - When a tensor storage is about to be finalized (garbage collected).
    # REF - When a tensor reference is updated, for instance, the gradients are marked as
    #       generic backward reference types until the grad_hook categorizes them as gradients.
    # SIZE - When a tensor's storage is resized.
    ADD = auto()
    DEL = auto()
    REF = auto()
    SIZE = auto()
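
# A worked example (editor's illustration) of a ``_UpdateType.REF`` update: when
# the grad hook recategorizes a 512-byte storage from TEMP to GRAD, the snapshot
# for its device changes from
#
#     {"Temp": 512, "Gradient": 0, ..., "Total": 512}
# to
#     {"Temp": 0, "Gradient": 512, ..., "Total": 512}
#
# i.e. memory moves between buckets while the device total stays constant.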
class MemTracker(TorchDispatchMode):
    """
    A TorchDispatchMode to track, categorize, and attribute the tensor memory created or accessed within its context.

    It categorizes the tracked tensors as parameters, buffers, activations, gradients, temporary memory, and optimizer
    states as defined by ``_MemRefType`` within its context. It captures memory `snapshots` for the modules called
    within its context, at the various states defined by ``_ModState``.

    Attributes:
        memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key
            is a reference to a module, and each value is a ``_ModMemStats`` object that stores the memory
            statistics of the module.

    Note:
        The MemTracker should be used as a context manager. The modules, optimizers, and any other tensors created
        within the context of MemTracker will be tracked by default. Any tensors or stateful objects such as modules,
        optimizers etc. that need to be tracked but are created outside the MemTracker should be registered using the
        ``track_external`` method, which must be called before the MemTracker is used. Any tensors created outside the
        ``MemTracker`` and not supplied to the ``track_external`` method will not be tracked by the ``MemTracker``.

    Example usage:

        .. code-block:: python

            module = ...
            optimizer = ...
            inp = ...
            mem_tracker = MemTracker()
            mem_tracker.track_external(module, optimizer, inp)
            with mem_tracker as mt:
                loss = module(inp)
                print("After Forward:")
                mt.display_snapshot("current")
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            mt.display_snapshot("peak")
            mt.display_modulewise_snapshots(depth=3, units="MiB")

    Known Limitations:
        - The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode``, e.g. under
          ``no_dispatch``.
        - Resizing tensor storages by means other than ``torch.UntypedStorage.resize_`` is not tracked. File a GitHub
          issue if you have use-cases for this.
        - If tensors are not traceable wrapper subclasses of ``torch.Tensor``, the tracker does not know how to
          track their storages. File a GitHub issue if you have use-cases for this.
        - During activation checkpointing (AC) in the backward pass there might be misattribution between activation
          and temporary memory, but the peak memory will be tracked accurately. This will be fixed in the next update
          by hooking more closely into ``torch.utils.checkpoint``.
    """

    def __init__(self) -> None:
        self.memory_tracking = WeakIdKeyDictionary()
        self._curr_mem_snap: Dict[torch.device, Dict[str, int]] = {}
        self._peak_mem: Dict[torch.device, int] = {}
        self._peak_mem_snap: Dict[torch.device, Dict[str, int]] = {}
        self._param_to_grad_hook_handles = WeakIdKeyDictionary()
        self._optimizer_hook_handles: Optional[
            Tuple[RemovableHandle, RemovableHandle]
        ] = None
        # Dictionary to store the ``_WeakRefInfo`` instances corresponding to each tensor's storage.
        self._WINFO = WeakIdKeyDictionary()
        self._mod_tracker = ModTracker()
        # This is a general memory tracker which can be used with any ``_RefType`` subclass.
        self._ref_class: Type[_RefType] = _MemRefType
        # Flags to track if we are in the AC region or the optimizer step region.
        self._in_opt: bool = False
        self._in_ac: bool = False
        # Weak reference to the topmost AC module currently active.
        self._ac_mod: Optional[weakref.ref] = None
        self._orig_resize = torch.UntypedStorage.resize_

    def _update_snap(
        self,
        u_type: _UpdateType,
        winfo: _WeakRefInfo,
        old_mem_consumed: Optional[int] = None,
        old_reftype: Optional[_RefType] = None,
    ) -> None:
        # Initialize a flag to track if the total memory might drop to zero after updates.
        maybe_zero = False
        # Ensure the device entry exists in the current memory snapshot, initializing it if necessary.
        dev_snap = self._curr_mem_snap.setdefault(
            winfo.device, dict.fromkeys(self._ref_class, 0)
        )
        dev_snap.setdefault(_TOTAL_KEY, 0)
        # Handle different types of updates based on the update type (``u_type``).
        if u_type == _UpdateType.ADD:
            # Increase the memory consumed for the specific reference type and update the total.
            dev_snap[winfo.reftype] += winfo.mem_consumed
            dev_snap[_TOTAL_KEY] += winfo.mem_consumed
        elif u_type == _UpdateType.DEL:
            # Decrease the memory consumed for the specific reference type and reduce the total.
            dev_snap[winfo.reftype] -= winfo.mem_consumed
            dev_snap[_TOTAL_KEY] -= winfo.mem_consumed
            maybe_zero = True
        elif u_type == _UpdateType.REF:
            assert old_reftype is not None
            # Adjust memory consumption between two reference types within the same device.
            dev_snap[old_reftype] -= winfo.mem_consumed
            dev_snap[winfo.reftype] += winfo.mem_consumed
        elif u_type == _UpdateType.SIZE:
            assert old_mem_consumed is not None
            # Adjust the memory consumed for a reference type due to a change in size.
            change = winfo.mem_consumed - old_mem_consumed
            dev_snap[winfo.reftype] += change
            dev_snap[_TOTAL_KEY] += change
            maybe_zero = True
        else:
            raise ValueError(f"Invalid update type: {u_type}")
        # Check if the total memory for the device has dropped to zero.
        if maybe_zero:
            if self._curr_mem_snap[winfo.device][_TOTAL_KEY] == 0:
                # Remove the device entry from the memory snapshot if the total memory is zero.
                del self._curr_mem_snap[winfo.device]
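
    # A worked example (editor's illustration) of ``_UpdateType.SIZE`` above: if a
    # storage previously accounted at ``old_mem_consumed == 1024`` bytes is resized
    # so that ``winfo.mem_consumed == 512``, then ``change == -512`` is applied to
    # both the reference-type bucket and the Total. If the Total for the device
    # then reaches zero (e.g. it was the only tracked storage and was resized to
    # 0), the device entry is dropped from ``_curr_mem_snap`` entirely.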
    def _update_and_maybe_create_winfos(
        self,
        t: torch.Tensor,
        reftype: _RefType,
        update_existing: bool = False,
    ) -> Set[_WeakRefInfo]:
        sts = _WeakRefInfo.get_untyped_storages(t)
        winfos = set()
        for st in sts:
            # Attempt to retrieve an existing ``_WeakRefInfo`` and its weak reference from the tracking dictionary.
            winfo, _ = self._WINFO.get(st, (None, None))
            if winfo is not None:
                # If ``_WeakRefInfo`` exists, check if the reference type needs to be updated.
                old_reftype = winfo.reftype
                if old_reftype != reftype:
                    # Update the reference type and apply changes via ``_update_snap``.
                    winfo.reftype = reftype
                    self._update_snap(_UpdateType.REF, winfo, old_reftype=old_reftype)
                winfos.add(winfo)
            elif update_existing:
                # If no existing ``_WeakRefInfo`` is found and ``update_existing`` is True, raise an error.
                raise KeyError("No existing winfo found")
            else:
                # If no existing ``_WeakRefInfo`` is found and ``update_existing`` is False, create a new ``_WeakRefInfo``.
                winfo, w_st = _WeakRefInfo.create_winfo(
                    st, t.device, reftype, self._delete_callback
                )
                # Store the new ``_WeakRefInfo`` and its weak reference in the tracking dictionary.
                self._WINFO[st] = (winfo, w_st)
                # Update the snapshot for the newly added ``_WeakRefInfo``.
                if winfo.mem_consumed > 0:
                    self._update_snap(_UpdateType.ADD, winfo)
                winfos.add(winfo)
        return winfos

    def _delete_callback(self, winfo: _WeakRefInfo, w_st: weakref.ref) -> None:
        # Callback to be called when the storage object corresponding to the ``_WeakRefInfo``
        # instance is about to be finalized.
        if winfo.mem_consumed > 0:
            self._update_snap(_UpdateType.DEL, winfo)
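
    # Context for the monkey-patch below (editor's note): ``resize_`` on an
    # ``UntypedStorage`` reallocates the underlying buffer without going through
    # the dispatcher, e.g.:
    #
    #     >>> st = torch.randn(4).untyped_storage()
    #     >>> st.size()  # bytes: 4 elements * 4 bytes
    #     16
    #     >>> _ = st.resize_(0)  # frees the buffer; invisible to TorchDispatchMode
    #     >>> st.size()
    #     0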
    def _track_resize(self) -> None:
        # Need to monkey-patch this because ``torch.UntypedStorage.resize_`` is not captured
        # by ``TorchDispatchMode``.
        @wraps(self._orig_resize)
        def resize_(st: torch.UntypedStorage, size: int) -> None:
            self._orig_resize(st, size)
            winfo, _ = self._WINFO.get(st, (None, None))
            if winfo is not None and winfo.size != st.size():
                old_mem_consumed = winfo.mem_consumed
                winfo.update_mem_consumed(st)
                self._update_snap(
                    _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed
                )

        torch.UntypedStorage.resize_ = resize_  # type: ignore[method-assign, assignment]

    def _restore_resize(self) -> None:
        torch.UntypedStorage.resize_ = self._orig_resize  # type: ignore[method-assign]

    def _update_peak_stats(self, peak_state: _State) -> None:
        # We first capture the current memory snapshot of the tracker state. Then
        # we step through each of the modules we have tracked so far in ``memory_tracking``
        # and check if it is currently active by querying ``_mod_tracker.parents``.
        # If it is active, we update the per-device peak memory usage for the module
        # corresponding to the ``_State``, which can be ``PEAK_FW`` or ``PEAK_BW``.
        curr_snap = self._curr_mem_snap

        for mod_stats in self.memory_tracking.values():
            if mod_stats.mod_fqn in self._mod_tracker.parents:
                if peak_state in mod_stats.snapshots:
                    for dev, dev_snap in curr_snap.items():
                        if mod_stats.local_peak.get(dev, 0) < dev_snap[_TOTAL_KEY]:
                            mod_stats.local_peak[dev] = dev_snap[_TOTAL_KEY]
                            mod_stats.snapshots[peak_state][-1][dev] = deepcopy(
                                dev_snap
                            )

        for dev, dev_snap in curr_snap.items():
            if self._peak_mem.get(dev, 0) < dev_snap[_TOTAL_KEY]:
                self._peak_mem[dev] = dev_snap[_TOTAL_KEY]
                self._peak_mem_snap[dev] = deepcopy(dev_snap)

    def _track(self, reftype: _RefType, t: torch.Tensor) -> None:
        # Get the storages of the tensor and check if we have already tracked them.
        # If yes, then check if the storage size has changed and update the current snapshot.
        # Else create a new ``_WeakRefInfo`` instance and add it to the dictionary.
        sts = _WeakRefInfo.get_untyped_storages(t)
        for st in sts:
            winfo, _ = self._WINFO.get(st, (None, None))
            if winfo is not None:
                if winfo.size != st.size():
                    old_mem_consumed = winfo.mem_consumed
                    winfo.update_mem_consumed(st)
                    self._update_snap(
                        _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed
                    )
                # Continue to the remaining storages (a wrapper subclass may have
                # several) instead of returning early.
                continue
            else:
                winfo, w_st = _WeakRefInfo.create_winfo(
                    st, t.device, reftype, self._delete_callback
                )
                self._WINFO[st] = (winfo, w_st)
                # Update the current snapshot for the newly added ``_WeakRefInfo``.
                if winfo.mem_consumed > 0:
                    self._update_snap(_UpdateType.ADD, winfo)

    def get_tracker_snapshot(
        self, type: str = "current"
    ) -> Dict[torch.device, Dict[str, int]]:
        """
        Capture a snapshot of the memory usage breakdown per device, based on the specified type.

        Args:
            type (str): The type of snapshot to capture. Can be "current" for the current memory usage or "peak" for
                the peak memory usage. Defaults to "current".
        Returns:
            Dict[torch.device, Dict[str, int]]: A dictionary where each key is a torch.device, and each value is
                another dictionary. This inner dictionary has keys representing memory reference types as defined in
                ``_MemRefType`` and values representing the amount of memory consumed in bytes.
        Raises:
            ValueError: If an invalid type is specified.
        """
        if type == "current":
            return deepcopy(self._curr_mem_snap)
        elif type == "peak":
            return deepcopy(self._peak_mem_snap)
        else:
            raise ValueError(f"Invalid type {type}")
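
    # A usage sketch (editor's illustration, CPU-only; assumes nothing else
    # allocates inside the context):
    #
    #     >>> with MemTracker() as mt:
    #     ...     t = torch.zeros(256)  # 256 * 4 bytes, categorized as an activation
    #     >>> peak = mt.get_tracker_snapshot("peak")
    #     >>> peak[torch.device("cpu")][_MemRefType.ACT]
    #     1024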
593 """ 594 if type == "current": 595 return deepcopy(self._curr_mem_snap) 596 elif type == "peak": 597 return deepcopy(self._peak_mem_snap) 598 else: 599 raise ValueError(f"Invalid type {type}") 600 601 def _track_module_params_and_buffers( 602 self, module: nn.Module, install_grad_hooks: bool = True 603 ) -> Tuple[int, int]: 604 # Track the parameters and buffers of the module if not already tracked. 605 # If the parameters have gradients, track the gradients as well. 606 # If install_grad_hooks is True, install a gradient hook on the parameters 607 # to track the gradients, if it has not already been installed. 608 # Return the total memory consumed by the parameters and buffers. 609 def _grad_hook(grad: torch.Tensor) -> None: 610 self._update_and_maybe_create_winfos( 611 grad, 612 _MemRefType.GRAD, 613 ) 614 615 param_memory = 0 616 for param in module.parameters(): 617 winfos = self._update_and_maybe_create_winfos( 618 param, 619 _MemRefType.PARAM, 620 ) 621 param_memory += sum(winfo.mem_consumed for winfo in winfos) 622 if param.grad is not None: 623 self._update_and_maybe_create_winfos( 624 param.grad, 625 _MemRefType.GRAD, 626 ) 627 if ( 628 self._param_to_grad_hook_handles.get(param, None) is None 629 and install_grad_hooks 630 ): 631 grad_hook_handle = param.register_hook(_grad_hook) 632 post_acc_grad_hook_handle = param.register_post_accumulate_grad_hook( 633 lambda p: (_grad_hook(p.grad)) 634 ) 635 self._param_to_grad_hook_handles[param] = ( 636 grad_hook_handle, 637 post_acc_grad_hook_handle, 638 ) 639 buffer_memory = 0 640 for buffer in module.buffers(): 641 winfos = self._update_and_maybe_create_winfos( 642 buffer, 643 _MemRefType.BUFFER, 644 ) 645 buffer_memory += sum(winfo.mem_consumed for winfo in winfos) 646 return (param_memory, buffer_memory) 647 648 def _track_inputs_or_outputs(self, args: Any) -> int: 649 # Calculate the memory consumed by the inputs or outputs of the module. 650 input_or_output_memory = 0 651 652 def add_inps_or_outs(t: torch.Tensor) -> None: 653 nonlocal input_or_output_memory 654 sts = _WeakRefInfo.get_untyped_storages(t) 655 for st in sts: 656 winfo, _ = self._WINFO.get(st, (None, None)) 657 if winfo is not None: 658 input_or_output_memory += winfo.mem_consumed 659 660 tree_map_only(torch.Tensor, add_inps_or_outs, args) 661 return input_or_output_memory 662 663 def _pre_fw_hook(self, module: nn.Module, inputs: Any) -> None: 664 # This is installed as a pre-fwd user hook with ``ModTracker.`` Based on the following cases we 665 # set the state and capture the memory snapshot for the module. 666 # Case 1: If the module is not in the ``memory_tracking`` dictionary, we track the parameters, buffers, 667 # input and output memory of the module. Create a new ``_ModMemStats`` instance for the module 668 # and add it to the ``memory_tracking`` dictionary. 669 # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means 670 # we are in the AC region. We check if this is the top most module in the AC region. If it is, 671 # we store a weak reference and set the flag ``_in_ac`` to True. 672 # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means 673 # this module is called for the second time. If it is a root module, that means we are in the next 674 # iteration and we error out. If it is not a root module, that means it's a submodule that is being 675 # used multiple times in the same iteration, which we allow and track. 
    def _pre_fw_hook(self, module: nn.Module, inputs: Any) -> None:
        # This is installed as a pre-fwd user hook with ``ModTracker``. Based on the following cases we
        # set the state and capture the memory snapshot for the module.
        # Case 1: If the module is not in the ``memory_tracking`` dictionary, we track the parameters, buffers,
        #         input and output memory of the module. Create a new ``_ModMemStats`` instance for the module
        #         and add it to the ``memory_tracking`` dictionary.
        # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means
        #         we are in the AC region. We check if this is the topmost module in the AC region. If it is,
        #         we store a weak reference and set the flag ``_in_ac`` to True.
        # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means
        #         this module is called for the second time. If it is a root module, that means we are in the next
        #         iteration and we error out. If it is not a root module, that means it's a submodule that is being
        #         used multiple times in the same iteration, which we allow and track.
        # For Cases 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module.
        mod_name = self._mod_tracker.get_known_fqn(module)
        assert mod_name is not None
        if module not in self.memory_tracking:
            mod_stats = _ModMemStats(mod_name)
            param_mem, buffer_mem = self._track_module_params_and_buffers(
                module, install_grad_hooks=True
            )
            input_mem = self._track_inputs_or_outputs(inputs)
            mod_stats.parameter_mem = param_mem
            mod_stats.buffer_mem = buffer_mem
            mod_stats.input_mem = input_mem
            self.memory_tracking[module] = mod_stats
            state = _ModState.PRE_FW
        elif self._mod_tracker.is_bw:
            mod_stats = self.memory_tracking[module]
            state = _ModState.PRE_FW_AC
            if self._ac_mod is None:
                self._ac_mod = weakref.ref(module)
                self._in_ac = True
        else:
            parents = set(self._mod_tracker.parents) - {mod_name}
            if len(parents) == 1 and "Global" in parents:
                raise NotImplementedError(
                    "MemTracker does not support memory tracking for multiple iterative calls."
                    " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration"
                    " or file a github issue if you need this feature."
                )
            mod_stats = self.memory_tracking[module]
            state = _ModState.PRE_FW
            input_mem = self._track_inputs_or_outputs(inputs)
            mod_stats.input_mem = input_mem

        mem_snapshot = self.get_tracker_snapshot()
        if state == _ModState.PRE_FW:
            mod_stats.local_peak = {
                dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items()
            }
            mod_stats.snapshots.setdefault(_ModState.PEAK_FW, []).append(mem_snapshot)
        mod_stats.snapshots.setdefault(state, []).append(deepcopy(mem_snapshot))

    def _post_fw_hook(self, module: nn.Module, inputs: Any, outputs: Any) -> None:
        # This is installed as a post-fwd user hook with ``ModTracker``. Based on the following cases we
        # set the state and capture the memory snapshot for the module.
        # Case 1: This is called in backward, which means we are in the AC region. If this is the topmost module
        #         in the AC region, we set the flag ``_in_ac`` to False.
        # Case 2: This is called in forward, so we calculate the output memory
        #         of the module and update its mod_stats.
        mod_stats = self.memory_tracking[module]
        if self._mod_tracker.is_bw:
            state = _ModState.POST_FW_AC
            if self._ac_mod is not None and self._ac_mod() is module:
                self._ac_mod = None
                self._in_ac = False
        else:
            state = _ModState.POST_FW
            output_mem = self._track_inputs_or_outputs(outputs)
            mod_stats.output_mem = output_mem
        mod_stats.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())
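
    # Editor's note on hook ordering: for a module that runs forward once inside
    # the tracker, ``_pre_fw_hook`` records a PEAK_FW snapshot (updated in place
    # as the peak moves) and a PRE_FW snapshot, and ``_post_fw_hook`` appends a
    # POST_FW snapshot, so (``mt`` and ``model`` being placeholders for a tracked run):
    #
    #     >>> mod_stats = mt.memory_tracking[model]
    #     >>> {_ModState.PRE_FW, _ModState.PEAK_FW, _ModState.POST_FW} <= mod_stats.snapshots.keys()
    #     True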
    def _pre_bw_hook(self, module: nn.Module, args: Any) -> None:
        # This is installed as a pre-bwd user hook with ``ModTracker``. We set the state and capture the
        # snapshot for the module. We also initialize the ``local_peak`` and ``PEAK_BW`` snapshot for it.
        # If the module is None, we skip the hook. This can happen because the hook is installed
        # inside a multi-grad hook on the module's output tensors and the module itself may not be
        # alive during backward.
        if module is None:
            warnings.warn("Module is None. Skipping PRE_BW hook.", stacklevel=2)
            return
        mod_stats = self.memory_tracking[module]
        mem_snapshot = self.get_tracker_snapshot()
        mod_stats.local_peak = {
            dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items()
        }
        mod_stats.snapshots.setdefault(_ModState.PEAK_BW, []).append(mem_snapshot)
        mod_stats.snapshots.setdefault(_ModState.PRE_BW, []).append(
            deepcopy(mem_snapshot)
        )

    def _post_bw_hook(self, module: nn.Module, args: Any) -> None:
        # This is installed as a post-bwd user hook with ``ModTracker``. We set the state and capture the
        # snapshot for the module if it is not None.
        # The module can be None because the hook is installed inside a multi-grad hook on the
        # module's input tensors and the module itself may not be alive during backward.
        if module is None:
            warnings.warn("Module is None. Skipping POST_BW hook.", stacklevel=2)
            return
        mod_stats = self.memory_tracking[module]
        mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
            self.get_tracker_snapshot()
        )

    def _track_optimizer_states(
        self, reftype: _RefType, optimizer: optim.Optimizer
    ) -> None:
        for states in optimizer.state.values():
            for val in states.values():
                if isinstance(val, torch.Tensor):
                    self._update_and_maybe_create_winfos(
                        val,
                        reftype,
                    )

    def _register_global_optimizer_hook(self) -> None:
        # Register hooks on the optimizer step to track the optimizer states.
        # The pre-hook sets the flag ``_in_opt`` to True. The post-hook unsets the flag,
        # and also tracks any optimizer states that are created during the optimizer step.
        def _opt_step_pre_hook(
            optimizer: optim.Optimizer, args: Any, kwargs: Any
        ) -> None:
            self._in_opt = True

        def _opt_step_post_hook(
            optimizer: optim.Optimizer, args: Any, kwargs: Any
        ) -> None:
            self._track_optimizer_states(_MemRefType.OPT, optimizer)
            self._in_opt = False

        self._optimizer_hook_handles = (
            register_optimizer_step_pre_hook(_opt_step_pre_hook),
            register_optimizer_step_post_hook(_opt_step_post_hook),
        )

    def _deregister_param_and_optimizer_hooks(self) -> None:
        for (
            grad_hook_handle,
            post_acc_grad_hook_handle,
        ) in self._param_to_grad_hook_handles.values():
            grad_hook_handle.remove()
            post_acc_grad_hook_handle.remove()
        self._param_to_grad_hook_handles.clear()

        if self._optimizer_hook_handles is not None:
            for handle in self._optimizer_hook_handles:
                handle.remove()
            self._optimizer_hook_handles = None
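
    # A minimal sketch (editor's illustration) of the global optimizer hooks used
    # above; ``torch.optim.optimizer.register_optimizer_step_pre_hook`` installs a
    # hook that fires before ``step()`` on every optimizer instance:
    #
    #     >>> from torch.optim.optimizer import register_optimizer_step_pre_hook
    #     >>> handle = register_optimizer_step_pre_hook(
    #     ...     lambda opt, args, kwargs: print(type(opt).__name__)
    #     ... )
    #     >>> opt = torch.optim.SGD([torch.randn(2, requires_grad=True)], lr=0.1)
    #     >>> opt.step()  # prints "SGD"
    #     SGD
    #     >>> handle.remove()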
828 """ 829 flat_external, _ = tree_flatten(external) 830 for obj in flat_external: 831 if isinstance(obj, torch.Tensor): 832 self._update_and_maybe_create_winfos( 833 obj, 834 _MemRefType.OTH, 835 ) 836 elif isinstance(obj, torch.nn.Module): 837 self._track_module_params_and_buffers(obj, install_grad_hooks=False) 838 elif isinstance(obj, optim.Optimizer): 839 self._track_optimizer_states(_MemRefType.OPT, obj) 840 else: 841 raise TypeError( 842 f"Object of type {type(obj)} is not supported for tracking. " 843 f"Only stateful objects like modules, optimizers, and tensors are supported." 844 ) 845 846 def display_snapshot( 847 self, type: str = "current", units: str = "B", tabulate: bool = False 848 ) -> None: 849 """ 850 Display the memory usage breakdown snapshot of the tracker based on the specified type and units. 851 852 Keyword args: 853 type (str): The type of snapshot to display. Can be "current" for the current memory usage or "peak" for the 854 peak memory usage. Defaults to "current". 855 units (str): The units to use for displaying memory usage. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. 856 tabulate (bool): Whether to display the snapshot in a tabular format. Defaults to False. 857 """ 858 snapshot = self.get_tracker_snapshot(type) 859 if tabulate: 860 _print_snapshot_tabular(snapshot, units) 861 else: 862 _print_snapshot(snapshot, units) 863 864 def display_modulewise_snapshots( 865 self, depth: int = 2, units: str = "B", tabulate: bool = False 866 ) -> None: 867 """ 868 Print per device memory breakdown snapshot for each module called within MemTracker. 869 870 Snapshots are displayed for the states defined by ``_ModState``. 871 The module hierarchy is displayed up to the specified depth. 872 873 Keyword Args: 874 depth (int, optional): The depth of the module hierarchy to display. Defaults to 2. 875 units (str, optional): The units to use for memory tracking. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. 876 tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False. 877 """ 878 879 def natural_sort_key(s: str) -> List[Union[int, str]]: 880 return [ 881 int(text) if text.isdigit() else text.lower() 882 for text in re.split("([0-9]+)", s) 883 ] 884 885 for mod_stats in sorted( 886 self.memory_tracking.values(), 887 key=lambda m_stats: natural_sort_key(m_stats.mod_fqn), 888 ): 889 mod_fqn = mod_stats.mod_fqn 890 mod_depth = mod_fqn.count(".") + 1 891 if mod_depth > depth: 892 continue 893 print(f"Module: {mod_fqn}") 894 if tabulate: 895 _print_state_snapshots_tabular(mod_stats.snapshots, units) 896 else: 897 _print_state_snapshots(mod_stats.snapshots, units) 898 899 def reset_mod_stats(self) -> None: 900 """ 901 Reset all the module memory stats. Clears ``memory_tracking`` dictionary. 
902 """ 903 self.memory_tracking.clear() 904 905 def __enter__(self) -> "MemTracker": 906 self._register_global_optimizer_hook() 907 self._mod_tracker.register_user_hooks( 908 self._pre_fw_hook, 909 self._post_fw_hook, 910 self._pre_bw_hook, 911 self._post_bw_hook, 912 ) 913 self._track_resize() 914 self._peak_mem_snap = self.get_tracker_snapshot() 915 self._peak_mem = { 916 dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items() 917 } 918 self._mod_tracker.__enter__() 919 super().__enter__() 920 return self 921 922 def __exit__(self, *args: Any) -> None: 923 self._deregister_param_and_optimizer_hooks() 924 self._mod_tracker.clear_user_hooks() 925 self._restore_resize() 926 super().__exit__(*args) 927 self._mod_tracker.__exit__(*args) 928 929 def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[no-untyped-def] 930 res = func(*args, **kwargs or {}) 931 # If we are tracking an optimizer state, we use the optimizer reference type. 932 # If we are in backward region and not in AC region, we use the backward reference type. 933 # Else we use the forward reference type. 934 if self._in_opt: 935 reftype = _MemRefType.OPT 936 elif self._mod_tracker.is_bw and not self._in_ac: 937 reftype = _MemRefType.TEMP 938 else: 939 reftype = _MemRefType.ACT 940 tree_map_only(torch.Tensor, partial(self._track, reftype), res) 941 peak_state = _ModState.PEAK_BW if self._mod_tracker.is_bw else _ModState.PEAK_FW 942 self._update_peak_stats(peak_state) 943 return res 944