# mypy: allow-untyped-defs
# mypy: allow-untyped-decorators
import torch
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from .module_tracker import ModuleTracker
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode
from math import prod
from functools import wraps
import warnings



__all__ = ["FlopCounterMode", "register_flop_formula"]

aten = torch.ops.aten

def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    return i

flop_registry: Dict[Any, Any] = {}

def shape_wrapper(f):
    @wraps(f)
    def nf(*args, out_val=None, **kwargs):
        args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val))
        return f(*args, out_shape=out_shape, **kwargs)
    return nf
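# Illustrative sketch (not part of the original module): shape_wrapper lets a formula
# written purely in terms of shapes be called with real tensors, e.g.
#   wrapped = shape_wrapper(lambda a_shape, b_shape, out_shape=None: a_shape[0] * b_shape[1])
#   wrapped(torch.randn(2, 3), torch.randn(3, 4))  # -> 8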

def register_flop_formula(targets, get_raw=False):
    def register_fun(flop_formula):
        if not get_raw:
            flop_formula = shape_wrapper(flop_formula)

        def register(target):
            if not isinstance(target, torch._ops.OpOverloadPacket):
                raise ValueError(
                    f"register_flop_formula(targets): expected each target to be "
                    f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got "
                    f"{target} which is of type {type(target)}")
            if target in flop_registry:
                raise RuntimeError(f"duplicate registrations for {target}")
            flop_registry[target] = flop_formula

        # To handle allowing multiple aten_ops at once
        torch.utils._pytree.tree_map_(register, targets)

        return flop_formula

    return register_fun

@register_flop_formula(aten.mm)
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for matmul."""
    # a_shape and b_shape are the shapes of the two matrices being multiplied.
    m, k = a_shape
    k2, n = b_shape
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return m * n * 2 * k
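# Illustrative check (not part of the original module): since the registered formula
# is shape-wrapped, it can be called directly with shapes, e.g.
#   mm_flop((64, 128), (128, 256))  # -> 2 * 64 * 128 * 256 == 4_194_304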

@register_flop_formula(aten.addmm)
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for addmm."""
    return mm_flop(a_shape, b_shape)

@register_flop_formula(aten.bmm)
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for the bmm operation."""
    # a_shape and b_shape are the shapes of the two batched matrices.
    b, m, k = a_shape
    b2, k2, n = b_shape
    assert b == b2
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    flop = b * m * n * 2 * k
    return flop

@register_flop_formula(aten.baddbmm)
def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for the baddbmm operation."""
    # self_shape, a_shape, b_shape are the shapes of the three input tensors;
    # only the batched matmul is counted (the bias add is ignored).
    return bmm_flop(a_shape, b_shape)


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> int:
    """Count flops for convolution.

    Note that each multiply-accumulate is counted as 2 flops; the computation
    for the bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(x_shape[2:]) * prod(w_shape) * 2.
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """

    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    c_out, c_in, *filter_size = w_shape

    """
    General idea here is that for a regular conv, for each point in the output
    spatial dimension we convolve the filter with something (hence
    `prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by
    1. batch_size, 2. the cross product of input and weight channels.

    For the transpose, it's not each point in the *output* spatial dimension but
    each point in the *input* spatial dimension.
    """
    # NB(chilli): I don't think this properly accounts for padding :think:
    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
    flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
    return flop
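# Illustrative check (not part of the original module): for x_shape=(1, 3, 32, 32),
# w_shape=(8, 3, 3, 3), out_shape=(1, 8, 30, 30) and transposed=False this gives
#   30 * 30 * (3 * 3) * 1 * 8 * 3 * 2 == 388_800 flops.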

@register_flop_formula([aten.convolution, aten._convolution])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
    """Count flops for convolution."""
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


@register_flop_formula(aten.convolution_backward)
def conv_backward_flop(
        grad_out_shape,
        x_shape,
        w_shape,
        _bias,
        _stride,
        _padding,
        _dilation,
        transposed,
        _output_padding,
        _groups,
        output_mask,
        out_shape) -> int:

    def t(shape):
        return [shape[1], shape[0]] + list(shape[2:])
    flop_count = 0

    """
    Let's say we have a regular 1D conv
    {A, B, C} [inp]
    {i, j} [weight]
    => (conv)
    {Ai + Bj, Bi + Cj} [out]

    And as a reminder, the transposed conv of the above is
    => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]

    For the backwards of conv, we now have
    {D, E} [grad_out]
    {A, B, C} [inp]
    {i, j} [weight]

    # grad_inp as conv_transpose(grad_out, weight)
    Let's first compute grad_inp. To do so, we can simply look at all the
    multiplications that each element of inp is involved in. For example, A is
    only involved in the first element of the output (and thus only depends upon
    D in grad_out), and C is only involved in the last element of the output
    (and thus only depends upon E in grad_out)

    {Di, Dj + Ei, Ej} [grad_inp]

    Note that this corresponds to the below conv_transpose. This gives us the
    output_mask[0] branch, which is grad_inp.

    {D, E} [inp (grad_out)]
    {i, j} [weight]
    => (conv_transpose)
    {Di, Dj + Ei, Ej} [out (grad_inp)]

    I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
    weight) as an exercise for the reader.

    # grad_weight as conv(inp, grad_out)
    To compute grad_weight, we again look at the terms in the output, which as
    a reminder is:
    => {Ai + Bj, Bi + Cj} [out]
    => {D, E} [grad_out]
    If we manually compute the gradient for the weights, we see it's
    {AD + BE, BD + CE} [grad_weight]

    This corresponds to the below conv
    {A, B, C} [inp]
    {D, E} [weight (grad_out)]
    => (conv)
    {AD + BE, BD + CE} [out (grad_weight)]

    # grad_weight of transposed conv as conv(grad_out, inp)
    As a reminder, the terms of the output of a transposed conv are:
    => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
    => {D, E, F, G} [grad_out]

    Manually computing the gradient for the weights, we see it's
    {AD + BE + CF, AE + BF + CG} [grad_weight]

    This corresponds to the below conv
    {D, E, F, G} [inp (grad_out)]
    {A, B, C} [weight (inp)]
    => (conv)
    {AD + BE + CF, AE + BF + CG} [out (grad_weight)]

    For the full backwards formula, there are also some details involving
    transpose of the batch/channel dimensions and groups, but I skip those for
    the sake of brevity (and they're pretty similar to matmul backwards)

    Check [conv backwards decomposition as conv forwards]
    """
    # grad_inp as conv_transpose(grad_out, weight)
    if output_mask[0]:
        grad_input_shape = get_shape(out_shape[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)

    if output_mask[1]:
        grad_weight_shape = get_shape(out_shape[1])
        if transposed:
            # grad_weight of transposed conv as conv(grad_out, inp)
            flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False)
        else:
            # grad_weight as conv(inp, grad_out)
            flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False)

    return flop_count
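# Illustrative mapping (not part of the original module): for a non-transposed conv
# with x_shape=(1, 3, 32, 32), w_shape=(8, 3, 3, 3), grad_out_shape=(1, 8, 30, 30)
# and output_mask=(True, True, False), the two branches above reduce to
#   conv_flop_count((1, 8, 30, 30), (8, 3, 3, 3), (1, 3, 32, 32), transposed=True)   # grad_inp
#   conv_flop_count((3, 1, 32, 32), (8, 1, 30, 30), (3, 8, 3, 3), transposed=False)  # grad_weight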

def sdpa_flop_count(query_shape, key_shape, value_shape):
    """
    Count flops for self-attention.

    NB: We can assume that value_shape == key_shape
    """
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3
    total_flops = 0
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
    return total_flops
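# Illustrative check (not part of the original module): for b=2, h=8, s_q=s_k=128,
# d_q=d_v=64, each of the two batched matmuls above counts 16 * 128 * 128 * 2 * 64
# flops, so sdpa_flop_count((2, 8, 128, 64), (2, 8, 128, 64), (2, 8, 128, 64))
# returns 67_108_864.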


@register_flop_formula([aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention])
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for self-attention."""
    # NB: We aren't accounting for causal attention here
    return sdpa_flop_count(query_shape, key_shape, value_shape)


def _offsets_to_lengths(offsets, max_len):
    """
    If the offsets tensor is fake, then we don't know the actual lengths.
    In that case, we can just assume the worst case; each batch has max length.
    """
    from torch._subclasses.fake_tensor import FakeTensor
    from torch._subclasses.functional_tensor import FunctionalTensor
    if not isinstance(offsets, (FakeTensor, FunctionalTensor)):
        return offsets.diff().tolist()
    return [max_len] * (offsets.size(0) - 1)
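# Illustrative check (not part of the original module): for a real offsets tensor
#   _offsets_to_lengths(torch.tensor([0, 3, 5, 9]), max_len=9)  # -> [3, 2, 4]
# while a fake/functional offsets tensor of the same size yields [9, 9, 9].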


def _unpack_flash_attention_nested_shapes(
    *,
    query,
    key,
    value,
    grad_out=None,
    cum_seq_q,
    cum_seq_k,
    max_q,
    max_k,
) -> Iterator[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Optional[Tuple[int, ...]]]]:
    """
    Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for
    NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
    each batch element.

    In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
    """
    if cum_seq_q is not None:
        # This means we should be dealing with a Nested Jagged Tensor query.
        # The inputs will have shape                  (sum(sequence len), heads, dimension)
        # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
        # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
        # So the flops calculation in this case is an overestimate of the actual flops.
        assert len(key.shape) == 3
        assert len(value.shape) == 3
        assert grad_out is None or grad_out.shape == query.shape
        _, h_q, d_q = query.shape
        _, h_k, d_k = key.shape
        _, h_v, d_v = value.shape
        assert cum_seq_q is not None
        assert cum_seq_k is not None
        assert cum_seq_q.shape == cum_seq_k.shape
        seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
        seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
        for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
            new_query_shape = (1, h_q, seq_q_len, d_q)
            new_key_shape = (1, h_k, seq_k_len, d_k)
            new_value_shape = (1, h_v, seq_k_len, d_v)
            new_grad_out_shape = new_query_shape if grad_out is not None else None
            yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
        return

    yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
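# Illustrative behavior (not part of the original module): with cum_seq_q = cum_seq_k =
# torch.tensor([0, 3, 5]), max_q = max_k = 3, grad_out=None and query/key/value each of
# shape (5, h, d), this yields (1, h, 3, d), (1, h, 3, d), (1, h, 3, d), None followed by
# (1, h, 2, d), (1, h, 2, d), (1, h, 2, d), None.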


def _unpack_efficient_attention_nested_shapes(
    *,
    query,
    key,
    value,
    grad_out=None,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
) -> Iterator[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Optional[Tuple[int, ...]]]]:
    """
    Given inputs to an efficient_attention_(forward|backward) kernel, this will handle behavior for
    NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
    each batch element.

    In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
    """
    if cu_seqlens_q is not None:
        # Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention.
        #
        # This means we should be dealing with a Nested Jagged Tensor query.
        # The inputs will have shape                  (sum(sequence len), heads, dimension)
        # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
        # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
        # So the flops calculation in this case is an overestimate of the actual flops.
        assert len(key.shape) == 4
        assert len(value.shape) == 4
        assert grad_out is None or grad_out.shape == query.shape
        _, _, h_q, d_q = query.shape
        _, _, h_k, d_k = key.shape
        _, _, h_v, d_v = value.shape
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert cu_seqlens_q.shape == cu_seqlens_k.shape
        seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
        seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
        for len_q, len_k in zip(seqlens_q, seqlens_k):
            new_query_shape = (1, h_q, len_q, d_q)
            new_key_shape = (1, h_k, len_k, d_k)
            new_value_shape = (1, h_v, len_k, d_v)
            new_grad_out_shape = new_query_shape if grad_out is not None else None
            yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
        return

    yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None


@register_flop_formula(aten._flash_attention_forward, get_raw=True)
def _flash_attention_forward_flop(
    query,
    key,
    value,
    cum_seq_q,
    cum_seq_k,
    max_q,
    max_k,
    *args,
    out_shape=None,
    **kwargs
) -> int:
    """Count flops for self-attention."""
    # NB: We aren't accounting for causal attention here
    # in case this is a nested tensor, we unpack the individual batch elements
    # and then sum the flops per batch element
    sizes = _unpack_flash_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
    )
    return sum(
        sdpa_flop_count(query_shape, key_shape, value_shape)
        for query_shape, key_shape, value_shape, _ in sizes
    )


@register_flop_formula(aten._efficient_attention_forward, get_raw=True)
def _efficient_attention_forward_flop(
    query,
    key,
    value,
    bias,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    *args,
    **kwargs
) -> int:
    """Count flops for self-attention."""
    # NB: We aren't accounting for causal attention here
    # in case this is a nested tensor, we unpack the individual batch elements
    # and then sum the flops per batch element
    sizes = _unpack_efficient_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
    )
    return sum(
        sdpa_flop_count(query_shape, key_shape, value_shape)
        for query_shape, key_shape, value_shape, _ in sizes
    )


def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
    total_flops = 0
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    _b4, _h4, _s4, _d4 = grad_out_shape
    assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
    assert d_v == _d4 and s_k == _s3 and s_q == _s4
    # Step 1: We recompute the scores matrix.
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))

    # Step 2: We propagate the gradients through the score @ v operation.
    # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
    # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
    total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))

    # Step 3: We propagate the gradients through the q @ k operation.
    # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
    # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
    total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
    return total_flops
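# Illustrative check (not part of the original module): for the same shapes as the
# forward example (b=2, h=8, s_q=s_k=128, d_q=d_v=64), the backward pass counts
# 5 batched matmuls instead of 2, i.e. 167_772_160 flops.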


@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, aten._scaled_dot_product_flash_attention_backward])
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for self-attention backward."""
    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)

@register_flop_formula(aten._flash_attention_backward, get_raw=True)
def _flash_attention_backward_flop(
    grad_out,
    query,
    key,
    value,
    out,  # named _out_shape to avoid kwarg collision with out_shape created in wrapper
    logsumexp,
    cum_seq_q,
    cum_seq_k,
    max_q,
    max_k,
    *args,
    **kwargs,
) -> int:
    # in case this is a nested tensor, we unpack the individual batch elements
    # and then sum the flops per batch element
    shapes = _unpack_flash_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        grad_out=grad_out,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
    )
    return sum(
        sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
        for query_shape, key_shape, value_shape, grad_out_shape in shapes
    )


@register_flop_formula(aten._efficient_attention_backward, get_raw=True)
def _efficient_attention_backward_flop(
    grad_out,
    query,
    key,
    value,
    bias,
    out,  # named _out to avoid kwarg collision with out created in wrapper
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    *args,
    **kwargs,
) -> int:
    # in case this is a nested tensor, we unpack the individual batch elements
    # and then sum the flops per batch element
    shapes = _unpack_efficient_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        grad_out=grad_out,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
    )
    return sum(
        sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
        for query_shape, key_shape, value_shape, grad_out_shape in shapes
    )


flop_registry = {
    aten.mm: mm_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.baddbmm: baddbmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    aten._scaled_dot_product_efficient_attention: sdpa_flop,
    aten._scaled_dot_product_flash_attention: sdpa_flop,
    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
    aten._flash_attention_forward: _flash_attention_forward_flop,
    aten._efficient_attention_forward: _efficient_attention_forward_flop,
    aten._flash_attention_backward: _flash_attention_backward_flop,
    aten._efficient_attention_backward: _efficient_attention_backward_flop,
}

def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


# Define the suffixes for different orders of magnitude
suffixes = ["", "K", "M", "B", "T"]
# Thanks BingChat!
def get_suffix_str(number):
    # Find the index of the appropriate suffix based on the number of digits
    # with some additional overflow.
    # i.e. 1.01B should be displayed as 1001M, not 1.001B
    index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3))
    return suffixes[index]

def convert_num_with_suffix(number, suffix):
    index = suffixes.index(suffix)
    # Divide the number by 1000^index and format it to three decimal places
    value = f"{number / 1000 ** index:.3f}"
    # Return the value and the suffix as a string
    return value + suffixes[index]
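# Illustrative check (not part of the original module):
#   get_suffix_str(1_234_567)                # -> "K"  (the overflow keeps 4 digits before the point)
#   convert_num_with_suffix(1_234_567, "K")  # -> "1234.567K"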

def convert_to_percent_str(num, denom):
    if denom == 0:
        return "0%"
    return f"{num / denom:.2%}"

def _pytreeify_preserve_structure(f):
    @wraps(f)
    def nf(args):
        flat_args, spec = tree_flatten(args)
        out = f(*flat_args)
        return tree_unflatten(out, spec)

    return nf


class FlopCounterMode(TorchDispatchMode):
    """
    ``FlopCounterMode`` is a context manager that counts the number of flops within its context.

    It does this using a ``TorchDispatchMode``.

    It also produces hierarchical, per-module output. Modules are tracked
    automatically, so passing a module (or list of modules) on construction is
    no longer needed; the ``mods`` argument is kept only for backwards
    compatibility.

    Example usage

    .. code-block:: python

        mod = ...
        inp = ...
        with FlopCounterMode() as flop_counter:
            mod(inp).sum().backward()

    """
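    # Illustrative sketch (not part of the original class): a custom flop formula can
    # be supplied via ``custom_mapping``, keyed by an OpOverloadPacket. The helper
    # name below is hypothetical.
    #   def my_mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs):
    #       m, k = a_shape
    #       k2, n = b_shape
    #       return m * n * (2 * k - 1)
    #   with FlopCounterMode(custom_mapping={torch.ops.aten.mm: my_mm_flop}):
    #       ...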

    def __init__(
            self,
            mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
            depth: int = 2,
            display: bool = True,
            custom_mapping: Optional[Dict[Any, Any]] = None):
        super().__init__()
        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
        self.depth = depth
        self.display = display
        if custom_mapping is None:
            custom_mapping = {}
        if mods is not None:
            warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2)
        self.flop_registry = {
            **flop_registry,
            **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
        }
        self.mod_tracker = ModuleTracker()

    def get_total_flops(self) -> int:
        return sum(self.flop_counts['Global'].values())

    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
        """Return the flop counts as a dictionary of dictionaries.

        The outer dictionary is keyed by module name, and the inner dictionary
        is keyed by operation name.

        Returns:
            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
        """
        return {k: dict(v) for k, v in self.flop_counts.items()}

    def get_table(self, depth=None):
        if depth is None:
            depth = self.depth
        if depth is None:
            depth = 999999

        import tabulate
        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []
        global_flops = self.get_total_flops()
        global_suffix = get_suffix_str(global_flops)
        is_global_subsumed = False

        def process_mod(mod_name, depth):
            nonlocal is_global_subsumed

            total_flops = sum(self.flop_counts[mod_name].values())

            is_global_subsumed |= total_flops >= global_flops

            padding = " " * depth
            values = []
            values.append([
                padding + mod_name,
                convert_num_with_suffix(total_flops, global_suffix),
                convert_to_percent_str(total_flops, global_flops)
            ])
            for k, v in self.flop_counts[mod_name].items():
                values.append([
                    padding + " - " + str(k),
                    convert_num_with_suffix(v, global_suffix),
                    convert_to_percent_str(v, global_flops)
                ])
            return values

        for mod in sorted(self.flop_counts.keys()):
            if mod == 'Global':
                continue
            mod_depth = mod.count(".") + 1
            if mod_depth > depth:
                continue

            cur_values = process_mod(mod, mod_depth - 1)
            values.extend(cur_values)

        # We do a bit of messing around here to only output the "Global" value
        # if there are any FLOPs in there that aren't already fully contained by
        # a module.
        if 'Global' in self.flop_counts and not is_global_subsumed:
            for idx in range(len(values)):
                values[idx][0] = " " + values[idx][0]

            values = process_mod('Global', 0) + values

        if len(values) == 0:
            values = [["Global", "0", "0%"]]

        return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))

    def __enter__(self):
        self.flop_counts.clear()
        self.mod_tracker.__enter__()
        super().__enter__()
        return self

    def __exit__(self, *args):
        super().__exit__(*args)
        self.mod_tracker.__exit__()
        if self.display:
            print(self.get_table(self.depth))

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        return self._count_flops(func._overloadpacket, out, args, kwargs)

    def _count_flops(self, func_packet, out, args, kwargs):
        if func_packet in self.flop_registry:
            flop_count_func = self.flop_registry[func_packet]
            flop_count = flop_count_func(*args, **kwargs, out_val=out)  # type: ignore[operator]
            for par in set(self.mod_tracker.parents):
                self.flop_counts[par][func_packet] += flop_count

        return out
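
# Illustrative usage (a sketch, not part of the original module): counting flops for
# a hypothetical linear layer in forward and backward.
#   lin = torch.nn.Linear(128, 256)
#   with FlopCounterMode(display=False) as flop_counter:
#       lin(torch.randn(32, 128)).sum().backward()
#   flop_counter.get_total_flops()   # counts the aten.addmm / aten.mm calls
#   print(flop_counter.get_table())  # hierarchical table, with a "Global" row if needed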