xref: /aosp_15_r20/external/pytorch/torch/distributed/_tools/mem_tracker.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import math
2import os
3import re
4import warnings
5from copy import deepcopy
6from enum import auto, Enum
7from functools import partial, wraps
8from typing import (
9    Any,
10    Callable,
11    Dict,
12    List,
13    Optional,
14    Set,
15    Tuple,
16    Type,
17    TYPE_CHECKING,
18    Union,
19)
20from typing_extensions import Self
21
22import torch
23from torch import nn, optim
24from torch.distributed._tools.mod_tracker import ModTracker
25from torch.optim.optimizer import (
26    register_optimizer_step_post_hook,
27    register_optimizer_step_pre_hook,
28)
29from torch.utils._python_dispatch import (
30    is_traceable_wrapper_subclass,
31    TorchDispatchMode,
32)
33from torch.utils._pytree import tree_flatten, tree_map_only
34from torch.utils.weak import WeakIdKeyDictionary, weakref
35
36
37if TYPE_CHECKING:
38    from torch.utils.hooks import RemovableHandle
39
40# This value is hard-coded here:
41# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
42_PYTORCH_MIN_ALLOCATE = (
43    2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1
44)
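# If ``PYTORCH_NO_CUDA_MEMORY_CACHING`` is set, the CUDA caching allocator is bypassed, so a value
# of 1 is used and no rounding is applied when accounting CUDA storages in
# ``_WeakRefInfo._calculate_mem_consumed`` below.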
45_TOTAL_KEY = "Total"
46
47__all__ = ["MemTracker"]
48
49
50class _RefType(str, Enum):
51    """Base Class for defining memory reference types, categorizing tensors based on their usage within a model."""
52
53
54class _State(str, Enum):
55    """Base Class for defining module state to capture snapshots ."""
56
57
58class _MemRefType(_RefType):
59    """
60    An enum to define memory reference types, categorizing tensors based on their usage within a model.
61
62        - PARAM: Tensors registered as nn.Parameter within modules.
63        - BUFFER: Tensors registered as nn.Buffer within modules.
64        - GRAD: Gradients associated with parameters.
65        - ACT: Tensors produced during the forward pass and recomputation in activation checkpointing.
66        - TEMP: Temporary memory used during the backward pass, including gradients of activations.
67        - OPT: Tensors holding optimizer states.
68        - OTH: Tensors registered via `track_external` that do not fit the above categories.
69    """
70
71    PARAM = "Parameter"
72    BUFFER = "Buffer"
73    GRAD = "Gradient"
74    ACT = "Activation"
75    TEMP = "Temp"
76    OPT = "Optstate"
77    OTH = "Other"
78
79
80class _ModState(_State):
81    """
82    An enum to define the state of a module.
83
84        - PRE_FW: The module is about to run the forward pass.
85        - POST_FW: The module has finished running the forward pass.
86        - PEAK_FW: The module has reached the peak memory usage during the forward pass.
87        - PRE_BW: The module is about to run the backward pass.
88        - PRE_FW_AC: The module is about to run the forward pass with activation checkpointing.
89        - POST_FW_AC: The module has finished running the forward pass with activation checkpointing.
90        - POST_BW: The module has finished running the backward pass.
91        - PEAK_BW: The module has reached the peak memory usage during the backward pass.
92    """
93
94    PRE_FW = "Pre-Forward"
95    POST_FW = "Post-Forward"
96    PEAK_FW = "Peak-Forward"
97    PRE_BW = "Pre-Backward"
98    PRE_FW_AC = "Pre-Forward-AC"
99    POST_FW_AC = "Post-Forward-AC"
100    POST_BW = "Post-Backward"
101    PEAK_BW = "Peak-Backward"
102
103
104class _ModMemStats:
105    """
106    A class to store the memory statistics of a module.
107
108    Args:
109        mod_fqn (str): The fully qualified name of the module.
110    Attributes:
111        mod_fqn (str): The fully qualified name of the module.
112        parameter_mem (int): The memory usage of the parameters of the module.
113        buffer_mem (int): The memory usage of the buffers of the module.
114        input_mem (int): The memory usage of the inputs to the module.
115        output_mem (int): The memory usage of the outputs from the module.
116        snapshots (Dict[_ModState, List[Dict[torch.device, Dict[str, int]]]]): A dictionary mapping each module
117        state defined by ``_ModState`` to the list of memory snapshots captured while the module was in that state.
118    Note:
119        Each memory snapshot is a dictionary of type Dict[torch.device, Dict[str, int]], where each key is a device,
120        and each value is another dictionary whose keys are memory reference types defined by `_MemRefType` and
121        whose values are the memory consumed in bytes.
122    """
123
124    def __init__(self, mod_fqn: str):
125        self.mod_fqn = mod_fqn
126        self.parameter_mem: int
127        self.buffer_mem: int
128        self.input_mem: int
129        self.output_mem: int
130        self.local_peak: Dict[torch.device, int] = {}
131        self.snapshots: Dict[_ModState, List[Dict[torch.device, Dict[str, int]]]] = {}
132
133
134class _WeakRefInfo:
135    """
136    Manages memory statistics and device attributes for tensor storages.
137    """
138
139    def __init__(
140        self, size: int, element_size: int, device: torch.device, reftype: _RefType
141    ) -> None:
142        """
143        Initializes the ``_WeakRefInfo`` object with tensor storage properties.
144
145        Args:
146            size (int): The number of elements in the tensor storage.
147            element_size (int): The size of each element in the tensor storage.
148            device (torch.device): The device on which the tensor is allocated.
149            reftype (_RefType): The reference type of the tensor.
150        """
151        self.size = size
152        self.element_size = element_size
153        self.reftype = reftype
154        self.device = device
155        self.mem_consumed = self._calculate_mem_consumed()
156
157    def _calculate_mem_consumed(self) -> int:
158        """
159        Calculates the memory consumed by the tensor storage, considering device-specific allocation rules.
160
161        Returns:
162            int: The memory consumed in bytes.
163        """
164        mem = self.size * self.element_size
165        if self.device.type == "cuda":
166            return math.ceil((mem) / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE
167        return mem
168
169    def update_mem_consumed(self, st: torch.UntypedStorage) -> int:
170        """
171        Updates and returns the memory consumed if the storage size has changed.
172
173        Args:
174            st (torch.UntypedStorage): The tensor storage to check for size updates.
175
176        Returns:
177            int: The updated memory consumed in bytes.
178        """
179        if st.size() != self.size:
180            self.size = st.size()
181            self.mem_consumed = self._calculate_mem_consumed()
182        return self.mem_consumed
183
184    @staticmethod
185    def get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]:
186        """
187        Recursively extracts untyped storages from a tensor or its subclasses.
188
189        Args:
190            t (torch.Tensor): The tensor to extract storages from.
191
192        Returns:
193            Set[torch.UntypedStorage]: A set of untyped storages.
194        """
195        unflattened_tensors = [t]
196        flattened_tensor_storages = set()
197        while len(unflattened_tensors) > 0:
198            obj = unflattened_tensors.pop()
199            if is_traceable_wrapper_subclass(obj):
200                attrs, _ = obj.__tensor_flatten__()  # type: ignore[attr-defined]
201                unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
202            else:
203                if not hasattr(obj, "untyped_storage"):
204                    warnings.warn(
205                        f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
206                        category=UserWarning,
207                        stacklevel=2,
208                    )
209                else:
210                    flattened_tensor_storages.add(obj.untyped_storage())
211        return flattened_tensor_storages
212
213    @classmethod
214    def create_winfo(
215        cls,
216        st: torch.UntypedStorage,
217        device: torch.device,
218        reftype: _RefType,
219        callback: Optional[Callable[[Self, weakref.ref], Any]] = None,
220    ) -> Tuple[Self, weakref.ref]:
221        """
222        Creates a new ``_WeakRefInfo`` instance and a weak reference to a ``torch.UntypedStorage`` object,
223        optionally attaching a callback to the weak reference.
224
225        Args:
226            st (torch.UntypedStorage): The storage object for which to create the weak reference info.
227            device (torch.device): The device associated with the storage object.
228            reftype (_RefType): The type of reference, used to categorize the storage.
229            callback (Optional[Callable[[Self, weakref.ref], Any]]): A callback function that is called when
230                the storage object is about to be finalized (garbage collected). The callback function
231                should accept two arguments: the ``_WeakRefInfo`` instance and the weak reference to the storage.
232        Returns:
233            Tuple[Self, weakref.ref]: A tuple containing the newly created ``_WeakRefInfo`` instance and the
234            weak reference to the storage object. The weak reference may have an attached callback if provided.
235        """
236
237        winfo = cls(st.size(), st.element_size(), device, reftype)
238        w_st = weakref.ref(st, partial(callback, winfo) if callback else None)
239        return winfo, w_st
240
241
242def _get_mem_divisor(units: str) -> int:
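    # e.g. _get_mem_divisor("MiB") returns 2**20, so byte counts are divided by 1048576 for display.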
243    unit_dict = {"B": 1, "KiB": 2**10, "MiB": 2**20, "GiB": 2**30}
244    if units in unit_dict:
245        return unit_dict[units]
246    else:
247        raise ValueError(
248            f"Unsupported unit: {units}. Supported units are: {', '.join(unit_dict.keys())}"
249        )
250
251
252def _rounding_fn(value: int, divisor: int, precision: int) -> Union[float, int]:
253    return value if divisor == 1 else round(value / divisor, precision)
254
255
256def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) -> None:
257    if len(snapshot) == 0:
258        print("No memory tracked.")
259        return
260    divisor = _get_mem_divisor(units)
261    for dev, dev_snap in snapshot.items():
262        if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
263            continue
264        print(
265            f"Device: {dev}",
266            *(
267                f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}"
268                for k, v in dev_snap.items()
269            ),
270            sep="\n",
271        )
272
273
274def _print_snapshot_tabular(
275    snapshot: Dict[torch.device, Dict[str, int]], units: str
276) -> None:
277    if len(snapshot) == 0:
278        print("No memory tracked.")
279        return
280    try:
281        from tabulate import tabulate
282    except ImportError as err:
283        raise ImportError(
284            "Please install tabulate to use the tabulate option."
285        ) from err
286    divisor = _get_mem_divisor(units)
287    table_data = []
288    key_list = list(next(iter(snapshot.values())).keys())
289    headers = ["Device"] + [f"{key}" for key in key_list]
290
291    for dev, dev_snap in snapshot.items():
292        if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
293            continue
294        row = [str(dev)]
295        row.extend(f"{_rounding_fn(v, divisor, 2)} {units}" for v in dev_snap.values())
296        table_data.append(row)
297    print(tabulate(table_data, headers=headers, tablefmt="rst"))
298
299
300def _print_state_snapshots(
301    snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str
302) -> None:
303    for state, snapshot_list in snapshots.items():
304        print(f"{state}")
305        for i, snapshot in enumerate(snapshot_list):
306            print(f"# {i + 1}:")
307            _print_snapshot(snapshot, units)
308    print()
309
310
311def _print_state_snapshots_tabular(
312    snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str
313) -> None:
314    try:
315        from tabulate import tabulate
316    except ImportError as err:
317        raise ImportError(
318            "Please install tabulate to use the tabulate option."
319        ) from err
320
321    table_data = []
322    last_state_call = None
323    divisor = _get_mem_divisor(units)
324    for state, snapshot_list in snapshots.items():
325        for i, snapshot in enumerate(snapshot_list):
326            state_call = f"{state} # {i + 1}"
327            for dev, dev_snap in snapshot.items():
328                if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0:
329                    continue
330                row = {
331                    "State & Call": (
332                        state_call if state_call != last_state_call else ""
333                    ),
334                    "Device": str(dev),
335                }
336                last_state_call = state_call
337                for k, v in dev_snap.items():
338                    row[f"{k}"] = f"{_rounding_fn(v, divisor, 2)} {units}"
339                table_data.append(row)
340    print(tabulate(table_data, headers="keys", tablefmt="rst"))
341
342
343class _UpdateType(Enum):
344    # These are used for tracking updates to the continuously maintained memory snapshot.
345    # ADD - When a new tensor storage is tracked
346    # DEL - When a tensor storage is about to be finalized (garbage collected).
347    # REF - When a tensor reference is updated, for instance, the gradients are marked as
348    #       generic backward reference types until the grad_hook categorizes them as gradients.
349    # SIZE - When a tensor's storage is resized.
350    ADD = auto()
351    DEL = auto()
352    REF = auto()
353    SIZE = auto()
354
355
356class MemTracker(TorchDispatchMode):
357    """
358    A TorchDispatchMode to track, categorize and attribute the tensor memory created or accessed within its context.
359
360    It categorizes the tracked tensors as parameters, buffers, activations, gradients, temporary memory and optimizer states
361    as defined by ``_MemRefType`` within its context. It captures memory `snapshots` for the modules, called within its context,
362    at various states defined by ``_ModState``.
363
364    Attributes:
365        memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key
366        is a reference to a module, and each value is a ``_ModMemStats`` object that stores the memory
367        statistics of the module.
368
369    Note:
370        The MemTracker should be used as a context manager. The modules, optimizers, and any other tensors created within
371        the context of MemTracker will be tracked by default. Any tensors or stateful objects such as modules, optimizers etc.
372        that need to be tracked but are created outside the MemTracker should be registered using the `track_external` method.
373        The `track_external` method should be called before the MemTracker is used. Any tensors created outside the ``MemTracker``
374        and not supplied to the `track_external` method will not be tracked by the ``MemTracker``.
375
376    Example usage:
377
378        .. code-block:: python
379
380            module = ...
381            optimizer = ...
382            inp = ...
383            mem_tracker = MemTracker()
384            mem_tracker.track_external(module, optimizer, inp)
385            with mem_tracker as mt:
386                loss = module(inp)
387                print("After Forward:")
388                mt.display_snapshot("current")
389                loss.backward()
390                optimizer.step()
391                optimizer.zero_grad()
392            mt.display_snapshot("peak")
393            mt.display_modulewise_snapshots(depth = 3, units = "MiB")
394
395    Known Limitations:
396        - The ``MemTracker`` does not track memory for tensors that bypass ``TorchDispatchMode``, e.g. under ``no_dispatch``.
397        - Directly resizing tensor storages by means other than ``torch.UntypedStorage.resize_``
398          is not tracked. File a GitHub issue if you have use-cases for this.
399        - If the tensors are not traceable wrapper subclasses of ``torch.Tensor``, then the tracker does not know how to
400          track their storages. File a GitHub issue if you have use-cases for this.
401        - During AC in the backward pass there might be misattribution between activation and temp memory, but the peak memory
402          will be tracked accurately. This will be fixed in the next update by hooking more closely into ``torch.utils.checkpoint``.
403    """
404
405    def __init__(self) -> None:
406        self.memory_tracking = WeakIdKeyDictionary()
407        self._curr_mem_snap: Dict[torch.device, Dict[str, int]] = {}
408        self._peak_mem: Dict[torch.device, int] = {}
409        self._peak_mem_snap: Dict[torch.device, Dict[str, int]] = {}
410        self._param_to_grad_hook_handles = WeakIdKeyDictionary()
411        self._optimizer_hook_handles: Optional[
412            Tuple[RemovableHandle, RemovableHandle]
413        ] = None
414        # Dictionary to store the ``_WeakRefInfo`` instances corresponding to each tensor's storage.
415        self._WINFO = WeakIdKeyDictionary()
416        self._mod_tracker = ModTracker()
417        # This is a general memory tracker which can be used with any ``_RefType`` subclass
418        self._ref_class: Type[_RefType] = _MemRefType
419        # Flags to track if we are in the AC region or optimizer step region
420        self._in_opt: bool = False
421        self._in_ac: bool = False
422        # Weak references to the topmost AC module currently active
423        self._ac_mod: Optional[weakref.ref] = None
424        self._orig_resize = torch.UntypedStorage.resize_
425
426    def _update_snap(
427        self,
428        u_type: _UpdateType,
429        winfo: _WeakRefInfo,
430        old_mem_consumed: Optional[int] = None,
431        old_reftype: Optional[_RefType] = None,
432    ) -> None:
433        # Initialize a flag to track if the total memory might drop to zero after updates.
434        maybe_zero = False
435        # Ensure the device entry exists in the current memory snapshot, initializing if necessary.
436        dev_snap = self._curr_mem_snap.setdefault(
437            winfo.device, dict.fromkeys(self._ref_class, 0)
438        )
439        dev_snap.setdefault(_TOTAL_KEY, 0)
440        # Handle different types of updates based on the update type (`u_type`).
441        if u_type == _UpdateType.ADD:
442            # Increase the memory consumed for the specific reference type and update the total.
443            dev_snap[winfo.reftype] += winfo.mem_consumed
444            dev_snap[_TOTAL_KEY] += winfo.mem_consumed
445        elif u_type == _UpdateType.DEL:
446            # Decrease the memory consumed for the specific reference type and reduce the total.
447            dev_snap[winfo.reftype] -= winfo.mem_consumed
448            dev_snap[_TOTAL_KEY] -= winfo.mem_consumed
449            maybe_zero = True
450        elif u_type == _UpdateType.REF:
451            assert old_reftype is not None
452            # Adjust memory consumption between two reference types within the same device.
453            dev_snap[old_reftype] -= winfo.mem_consumed
454            dev_snap[winfo.reftype] += winfo.mem_consumed
455        elif u_type == _UpdateType.SIZE:
456            assert old_mem_consumed is not None
457            # Adjust the memory consumed for a reference type due to a change in size.
458            change = winfo.mem_consumed - old_mem_consumed
459            dev_snap[winfo.reftype] += change
460            dev_snap[_TOTAL_KEY] += change
461            maybe_zero = True
462        else:
463            raise ValueError(f"Invalid update type: {u_type}")
464        # Check if the total memory for the device has dropped to zero.
465        if maybe_zero:
466            if self._curr_mem_snap[winfo.device][_TOTAL_KEY] == 0:
467                # Remove the device entry from the memory snapshot if the total memory is zero.
468                del self._curr_mem_snap[winfo.device]
469
470    def _update_and_maybe_create_winfos(
471        self,
472        t: torch.Tensor,
473        reftype: _RefType,
474        update_existing: bool = False,
475    ) -> Set[_WeakRefInfo]:
476        sts = _WeakRefInfo.get_untyped_storages(t)
477        winfos = set()
478        for st in sts:
479            # Attempt to retrieve existing ``_WeakRefInfo`` and its weak reference from the tracking dictionary.
480            winfo, _ = self._WINFO.get(st, (None, None))
481            if winfo is not None:
482                # If ``_WeakRefInfo`` exists, check if the reference type needs to be updated.
483                old_reftype = winfo.reftype
484                if old_reftype != reftype:
485                    # Update the reference type and apply changes via ``_update_snap``.
486                    winfo.reftype = reftype
487                    self._update_snap(_UpdateType.REF, winfo, old_reftype=old_reftype)
488                winfos.add(winfo)
489            elif update_existing:
490                # If no existing ``_WeakRefInfo`` is found and update_existing is True, raise an error.
491                raise KeyError("No existing winfo found")
492            else:
493                # If no existing _WeakRefInfo is found and update_existing is False, create a new ``_WeakRefInfo``.
494                winfo, w_st = _WeakRefInfo.create_winfo(
495                    st, t.device, reftype, self._delete_callback
496                )
497                # Store the new ``_WeakRefInfo`` and its weak reference in the tracking dictionary.
498                self._WINFO[st] = (winfo, w_st)
499                # Update the snapshot for the newly added ``_WeakRefInfo``.
500                if winfo.mem_consumed > 0:
501                    self._update_snap(_UpdateType.ADD, winfo)
502                winfos.add(winfo)
503        return winfos
504
505    def _delete_callback(self, winfo: _WeakRefInfo, w_st: weakref.ref) -> None:
506        # Callback to be called when the storage object corresponding to the ``_WeakRefInfo``
507        # instance is about to be finalized.
508        if winfo.mem_consumed > 0:
509            self._update_snap(_UpdateType.DEL, winfo)
510
511    def _track_resize(self) -> None:
512        # Need to monkey-patch this because ``torch.UntypedStorage.resize_`` is not captured
513        # by ``TorchDispatchMode``.
514        @wraps(self._orig_resize)
515        def resize_(st: torch.UntypedStorage, size: int) -> None:
516            self._orig_resize(st, size)
517            winfo, _ = self._WINFO.get(st, (None, None))
518            if winfo is not None and winfo.size != st.size():
519                old_mem_consumed = winfo.mem_consumed
520                winfo.update_mem_consumed(st)
521                self._update_snap(
522                    _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed
523                )
524
525        torch.UntypedStorage.resize_ = resize_  # type: ignore[method-assign, assignment]
526
527    def _restore_resize(self) -> None:
528        torch.UntypedStorage.resize_ = self._orig_resize  # type: ignore[method-assign]
529
530    def _update_peak_stats(self, peak_state: _State) -> None:
531        # We first capture the current memory snapshot of the tracker state. Then,
532        # we step through each of the modules we have tracked so far in ``memory_tracking``
533        # and check if it is currently active by querying ``_mod_tracker.parents``.
534        # If it is active, we update the per-device peak memory usage for the module
535        # corresponding to the ``_State``, which can be ``PEAK_FW`` or ``PEAK_BW``.
536        curr_snap = self._curr_mem_snap
537
538        for mod_stats in self.memory_tracking.values():
539            if mod_stats.mod_fqn in self._mod_tracker.parents:
540                if peak_state in mod_stats.snapshots:
541                    for dev, dev_snap in curr_snap.items():
542                        if mod_stats.local_peak.get(dev, 0) < dev_snap[_TOTAL_KEY]:
543                            mod_stats.local_peak[dev] = dev_snap[_TOTAL_KEY]
544                            mod_stats.snapshots[peak_state][-1][dev] = deepcopy(
545                                dev_snap
546                            )
547
548        for dev, dev_snap in curr_snap.items():
549            if self._peak_mem.get(dev, 0) < dev_snap[_TOTAL_KEY]:
550                self._peak_mem[dev] = dev_snap[_TOTAL_KEY]
551                self._peak_mem_snap[dev] = deepcopy(dev_snap)
552
553    def _track(self, reftype: _RefType, t: torch.Tensor) -> None:
554        # Get the storages of the tensor and check if we have already tracked them.
555        # If yes, then check if the storage size has changed and update the current snapshot.
556        # Else create a new ``_WeakRefInfo`` instance and add it to the dictionary.
557        sts = _WeakRefInfo.get_untyped_storages(t)
558        for st in sts:
559            winfo, _ = self._WINFO.get(st, (None, None))
560            if winfo is not None:
561                if winfo.size != st.size():
562                    old_mem_consumed = winfo.mem_consumed
563                    winfo.update_mem_consumed(st)
564                    self._update_snap(
565                        _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed
566                    )
567                continue  # move on to the next storage of this tensor
568            else:
569                winfo, w_st = _WeakRefInfo.create_winfo(
570                    st, t.device, reftype, self._delete_callback
571                )
572                self._WINFO[st] = (winfo, w_st)
573                # Update the current snapshot for the newly added ``_WeakRefInfo``.
574                if winfo.mem_consumed > 0:
575                    self._update_snap(_UpdateType.ADD, winfo)
576
577    def get_tracker_snapshot(
578        self, type: str = "current"
579    ) -> Dict[torch.device, Dict[str, int]]:
580        """
581        Capture a snapshot of the memory usage breakdown per device, based on the specified type.
582
583        Args:
584            type (str): The type of snapshot to capture. Can be "current" for the current memory usage or "peak" for the
585                        peak memory usage. Defaults to "current".
586        Returns:
587            Dict[torch.device, Dict[str, int]]: A dictionary where each key is a torch.device, and each value is another
588                                                dictionary. This inner dictionary has keys representing memory reference
589                                                types as defined in ``_MemRefType`` and values representing the amount of
590                                                memory consumed in bytes.
591        Raises:
592            ValueError: If an invalid type is specified.
593        """
594        if type == "current":
595            return deepcopy(self._curr_mem_snap)
596        elif type == "peak":
597            return deepcopy(self._peak_mem_snap)
598        else:
599            raise ValueError(f"Invalid type {type}")
600
601    def _track_module_params_and_buffers(
602        self, module: nn.Module, install_grad_hooks: bool = True
603    ) -> Tuple[int, int]:
604        # Track the parameters and buffers of the module if not already tracked.
605        # If the parameters have gradients, track the gradients as well.
606        # If install_grad_hooks is True, install a gradient hook on the parameters
607        #  to track the gradients, if it has not already been installed.
608        # Return the total memory consumed by the parameters and buffers.
609        def _grad_hook(grad: torch.Tensor) -> None:
610            self._update_and_maybe_create_winfos(
611                grad,
612                _MemRefType.GRAD,
613            )
614
615        param_memory = 0
616        for param in module.parameters():
617            winfos = self._update_and_maybe_create_winfos(
618                param,
619                _MemRefType.PARAM,
620            )
621            param_memory += sum(winfo.mem_consumed for winfo in winfos)
622            if param.grad is not None:
623                self._update_and_maybe_create_winfos(
624                    param.grad,
625                    _MemRefType.GRAD,
626                )
627            if (
628                self._param_to_grad_hook_handles.get(param, None) is None
629                and install_grad_hooks
630            ):
631                grad_hook_handle = param.register_hook(_grad_hook)
632                post_acc_grad_hook_handle = param.register_post_accumulate_grad_hook(
633                    lambda p: (_grad_hook(p.grad))
634                )
635                self._param_to_grad_hook_handles[param] = (
636                    grad_hook_handle,
637                    post_acc_grad_hook_handle,
638                )
639        buffer_memory = 0
640        for buffer in module.buffers():
641            winfos = self._update_and_maybe_create_winfos(
642                buffer,
643                _MemRefType.BUFFER,
644            )
645            buffer_memory += sum(winfo.mem_consumed for winfo in winfos)
646        return (param_memory, buffer_memory)
647
648    def _track_inputs_or_outputs(self, args: Any) -> int:
649        # Calculate the memory consumed by the inputs or outputs of the module.
650        input_or_output_memory = 0
651
652        def add_inps_or_outs(t: torch.Tensor) -> None:
653            nonlocal input_or_output_memory
654            sts = _WeakRefInfo.get_untyped_storages(t)
655            for st in sts:
656                winfo, _ = self._WINFO.get(st, (None, None))
657                if winfo is not None:
658                    input_or_output_memory += winfo.mem_consumed
659
660        tree_map_only(torch.Tensor, add_inps_or_outs, args)
661        return input_or_output_memory
662
663    def _pre_fw_hook(self, module: nn.Module, inputs: Any) -> None:
664        # This is installed as a pre-fwd user hook with ``ModTracker``. Based on the following cases we
665        # set the state and capture the memory snapshot for the module.
666        # Case 1: If the module is not in the ``memory_tracking`` dictionary, we track the parameters, buffers,
667        #         input and output memory of the module. Create a new ``_ModMemStats`` instance for the module
668        #         and add it to the ``memory_tracking`` dictionary.
669        # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means
670        #         we are in the AC region. We check if this is the top most module in the AC region. If it is,
671        #         we store a weak reference and set the flag ``_in_ac`` to True.
672        # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means
673        #         this module is called for the second time. If it is a root module, that means we are in the next
674        #         iteration and we error out. If it is not a root module, that means it's a submodule that is being
675        #         used multiple times in the same iteration, which we allow and track.
676        # For Cases 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshots for the module.
677        mod_name = self._mod_tracker.get_known_fqn(module)
678        assert mod_name is not None
679        if module not in self.memory_tracking:
680            mod_stats = _ModMemStats(mod_name)
681            param_mem, buffer_mem = self._track_module_params_and_buffers(
682                module, install_grad_hooks=True
683            )
684            input_mem = self._track_inputs_or_outputs(inputs)
685            mod_stats.parameter_mem = param_mem
686            mod_stats.buffer_mem = buffer_mem
687            mod_stats.input_mem = input_mem
688            self.memory_tracking[module] = mod_stats
689            state = _ModState.PRE_FW
690
691        elif self._mod_tracker.is_bw:
692            mod_stats = self.memory_tracking[module]
693            state = _ModState.PRE_FW_AC
694            if self._ac_mod is None:
695                self._ac_mod = weakref.ref(module)
696                self._in_ac = True
697        else:
698            parents = set(self._mod_tracker.parents) - {mod_name}
699            if len(parents) == 1 and "Global" in parents:
700                raise NotImplementedError(
701                    "MemTracker does not support memory tracking for multiple iterative calls."
702                    " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration"
703                    " or file a github issue if you need this feature."
704                )
705            mod_stats = self.memory_tracking[module]
706            state = _ModState.PRE_FW
707            input_mem = self._track_inputs_or_outputs(inputs)
708            mod_stats.input_mem = input_mem
709
710        mem_snapshot = self.get_tracker_snapshot()
711        if state == _ModState.PRE_FW:
712            mod_stats.local_peak = {
713                dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items()
714            }
715            mod_stats.snapshots.setdefault(_ModState.PEAK_FW, []).append(mem_snapshot)
716        mod_stats.snapshots.setdefault(state, []).append(deepcopy(mem_snapshot))
717
718    def _post_fw_hook(self, module: nn.Module, inputs: Any, outputs: Any) -> None:
719        # This is installed as a post-fwd user hook with ``ModTracker``. Based on the following cases we
720        # set the state and capture the memory snapshot for the module.
721        # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module
722        #         in the AC region, we set the flag ``_in_ac`` to False.
723        # Case 2: This is called in forward so we calculate the output memory
724        #         of the module and update its mod_stats.
725        mod_stats = self.memory_tracking[module]
726        if self._mod_tracker.is_bw:
727            state = _ModState.POST_FW_AC
728            if self._ac_mod is not None and self._ac_mod() is module:
729                self._ac_mod = None
730                self._in_ac = False
731        else:
732            state = _ModState.POST_FW
733            output_mem = self._track_inputs_or_outputs(outputs)
734            mod_stats.output_mem = output_mem
735        mod_stats.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())
736
737    def _pre_bw_hook(self, module: nn.Module, args: Any) -> None:
738        # This is installed as a pre-bwd user hook with ``ModTracker``. We set the state and capture the
739        # snapshot for the module. We also initialize the ``local_peak`` and ``PEAK_BW`` snapshot for it.
740        # If the module is None, we skip the hook.
741        # This can happen since this hook is installed inside a multi-grad hook on the module's output tensors
742        # and the module itself may not be alive during backward.
743        if module is None:
744            warnings.warn("Module is None. Skipping PRE_BW hook.", stacklevel=2)
745            return
746        mod_stats = self.memory_tracking[module]
747        mem_snapshot = self.get_tracker_snapshot()
748        mod_stats.local_peak = {
749            dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items()
750        }
751        mod_stats.snapshots.setdefault(_ModState.PEAK_BW, []).append(mem_snapshot)
752        mod_stats.snapshots.setdefault(_ModState.PRE_BW, []).append(
753            deepcopy(mem_snapshot)
754        )
755
756    def _post_bw_hook(self, module: nn.Module, args: Any) -> None:
757        # This is installed as a post-bwd user hook with ``ModTracker``. We set the state and capture the
758        # snapshot for the module if it is not None.
759        # This can happen since this hook is installed inside a multi-grad hook on the module's input tensors
760        # and the module itself may not be alive during backward.
761        if module is None:
762            warnings.warn("Module is None. Skipping POST_BW hook.", stacklevel=2)
763            return
764        mod_stats = self.memory_tracking[module]
765        mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
766            self.get_tracker_snapshot()
767        )
768
769    def _track_optimizer_states(
770        self, reftype: _RefType, optimizer: optim.Optimizer
771    ) -> None:
772        for states in optimizer.state.values():
773            for val in states.values():
774                if isinstance(val, torch.Tensor):
775                    self._update_and_maybe_create_winfos(
776                        val,
777                        reftype,
778                    )
779
780    def _register_global_optimizer_hook(self) -> None:
781        # Register a hook on the optimizer step to track the optimizer states.
782        # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag,
783        # and also tracks any optimizer states that are created during the optimizer step.
784        def _opt_step_pre_hook(
785            optimizer: optim.Optimizer, args: Any, kwargs: Any
786        ) -> None:
787            self._in_opt = True
788
789        def _opt_step_post_hook(
790            optimizer: optim.Optimizer, args: Any, kwargs: Any
791        ) -> None:
792            self._track_optimizer_states(_MemRefType.OPT, optimizer)
793            self._in_opt = False
794
795        self._optimizer_hook_handles = (
796            register_optimizer_step_pre_hook(_opt_step_pre_hook),
797            register_optimizer_step_post_hook(_opt_step_post_hook),
798        )
799
800    def _deregister_param_and_optimizer_hooks(self) -> None:
801        for (
802            grad_hook_handle,
803            post_acc_grad_hook_handle,
804        ) in self._param_to_grad_hook_handles.values():
805            grad_hook_handle.remove()
806            post_acc_grad_hook_handle.remove()
807        self._param_to_grad_hook_handles.clear()
808
809        if self._optimizer_hook_handles is not None:
810            for handle in self._optimizer_hook_handles:
811                handle.remove()
812            self._optimizer_hook_handles = None
813
814    def track_external(
815        self, *external: Union[nn.Module, optim.Optimizer, torch.Tensor]
816    ) -> None:
817        """
818        Track tensors and stateful objects like modules, optimizers etc. that are created outside the MemTracker.
819
820        This method should be called before the ``MemTracker`` is used. Any tensors that are not module parameters, buffers,
821        gradients, activations, or optimizer states will be categorized as ``Other``. If you want them categorized with a
822        custom name, please file a GitHub issue. Any tensors created outside the MemTracker and not supplied to this
823        method will not be tracked by ``MemTracker``.
824
825        Args:
826            *external (Union[nn.Module, optim.Optimizer, torch.Tensor]): The external modules, optimizers, and
827                                                                         tensors to be tracked.
828        """
829        flat_external, _ = tree_flatten(external)
830        for obj in flat_external:
831            if isinstance(obj, torch.Tensor):
832                self._update_and_maybe_create_winfos(
833                    obj,
834                    _MemRefType.OTH,
835                )
836            elif isinstance(obj, torch.nn.Module):
837                self._track_module_params_and_buffers(obj, install_grad_hooks=False)
838            elif isinstance(obj, optim.Optimizer):
839                self._track_optimizer_states(_MemRefType.OPT, obj)
840            else:
841                raise TypeError(
842                    f"Object of type {type(obj)} is not supported for tracking. "
843                    f"Only stateful objects like modules, optimizers, and tensors are supported."
844                )
845
846    def display_snapshot(
847        self, type: str = "current", units: str = "B", tabulate: bool = False
848    ) -> None:
849        """
850        Display the memory usage breakdown snapshot of the tracker based on the specified type and units.
851
852        Keyword args:
853            type (str): The type of snapshot to display. Can be "current" for the current memory usage or "peak" for the
854                        peak memory usage. Defaults to "current".
855            units (str): The units to use for displaying memory usage. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"].
856            tabulate (bool): Whether to display the snapshot in a tabular format. Defaults to False.
857        """
858        snapshot = self.get_tracker_snapshot(type)
859        if tabulate:
860            _print_snapshot_tabular(snapshot, units)
861        else:
862            _print_snapshot(snapshot, units)
863
864    def display_modulewise_snapshots(
865        self, depth: int = 2, units: str = "B", tabulate: bool = False
866    ) -> None:
867        """
868        Print per device memory breakdown snapshot for each module called within MemTracker.
869
870        Snapshots are displayed for the states defined by ``_ModState``.
871        The module hierarchy is displayed up to the specified depth.
872
873        Keyword Args:
874            depth (int, optional): The depth of the module hierarchy to display. Defaults to 2.
875            units (str, optional): The units to use for memory tracking. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"].
876            tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False.
877        """
878
879        def natural_sort_key(s: str) -> List[Union[int, str]]:
880            return [
881                int(text) if text.isdigit() else text.lower()
882                for text in re.split("([0-9]+)", s)
883            ]
884
885        for mod_stats in sorted(
886            self.memory_tracking.values(),
887            key=lambda m_stats: natural_sort_key(m_stats.mod_fqn),
888        ):
889            mod_fqn = mod_stats.mod_fqn
890            mod_depth = mod_fqn.count(".") + 1
891            if mod_depth > depth:
892                continue
893            print(f"Module:  {mod_fqn}")
894            if tabulate:
895                _print_state_snapshots_tabular(mod_stats.snapshots, units)
896            else:
897                _print_state_snapshots(mod_stats.snapshots, units)
898
899    def reset_mod_stats(self) -> None:
900        """
901        Reset all the module memory stats. Clears ``memory_tracking`` dictionary.
902        """
903        self.memory_tracking.clear()
904
905    def __enter__(self) -> "MemTracker":
906        self._register_global_optimizer_hook()
907        self._mod_tracker.register_user_hooks(
908            self._pre_fw_hook,
909            self._post_fw_hook,
910            self._pre_bw_hook,
911            self._post_bw_hook,
912        )
913        self._track_resize()
914        self._peak_mem_snap = self.get_tracker_snapshot()
915        self._peak_mem = {
916            dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items()
917        }
918        self._mod_tracker.__enter__()
919        super().__enter__()
920        return self
921
922    def __exit__(self, *args: Any) -> None:
923        self._deregister_param_and_optimizer_hooks()
924        self._mod_tracker.clear_user_hooks()
925        self._restore_resize()
926        super().__exit__(*args)
927        self._mod_tracker.__exit__(*args)
928
929    def __torch_dispatch__(self, func, types, args=(), kwargs=None):  # type: ignore[no-untyped-def]
930        res = func(*args, **kwargs or {})
931        # If we are tracking an optimizer state, we use the optimizer reference type.
932        # If we are in backward region and not in AC region, we use the backward reference type.
933        # Else we use the forward reference type.
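        # e.g. a tensor created inside ``optimizer.step()`` is tagged "Optstate", one created
        # during ``loss.backward()`` outside any AC region is tagged "Temp", and one created in
        # forward or during AC recomputation is tagged "Activation".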
934        if self._in_opt:
935            reftype = _MemRefType.OPT
936        elif self._mod_tracker.is_bw and not self._in_ac:
937            reftype = _MemRefType.TEMP
938        else:
939            reftype = _MemRefType.ACT
940        tree_map_only(torch.Tensor, partial(self._track, reftype), res)
941        peak_state = _ModState.PEAK_BW if self._mod_tracker.is_bw else _ModState.PEAK_FW
942        self._update_peak_stats(peak_state)
943        return res
944