# mypy: allow-untyped-defs
# mypy: allow-untyped-decorators
import torch
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from .module_tracker import ModuleTracker
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode
from math import prod
from functools import wraps
import warnings



__all__ = ["FlopCounterMode", "register_flop_formula"]

aten = torch.ops.aten

def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    return i

flop_registry: Dict[Any, Any] = {}

def shape_wrapper(f):
    @wraps(f)
    def nf(*args, out_val=None, **kwargs):
        args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val))
        return f(*args, out_shape=out_shape, **kwargs)
    return nf

def register_flop_formula(targets, get_raw=False):
    def register_fun(flop_formula):
        if not get_raw:
            flop_formula = shape_wrapper(flop_formula)

        def register(target):
            if not isinstance(target, torch._ops.OpOverloadPacket):
                raise ValueError(
                    f"register_flop_formula(targets): expected each target to be "
                    f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got "
                    f"{target} which is of type {type(target)}")
            if target in flop_registry:
                raise RuntimeError(f"duplicate registrations for {target}")
            flop_registry[target] = flop_formula

        # To allow registering multiple aten ops at once
        torch.utils._pytree.tree_map_(register, targets)

        return flop_formula

    return register_fun

@register_flop_formula(aten.mm)
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for matmul."""
    # Inputs should be a list of length 2.
    # Inputs contain the shapes of the two matrices.
    m, k = a_shape
    k2, n = b_shape
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return m * n * 2 * k

@register_flop_formula(aten.addmm)
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for addmm."""
    return mm_flop(a_shape, b_shape)

@register_flop_formula(aten.bmm)
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for the bmm operation."""
    # Inputs should be a list of length 2.
    # Inputs contain the shapes of the two tensors.
    b, m, k = a_shape
    b2, k2, n = b_shape
    assert b == b2
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    flop = b * m * n * 2 * k
    return flop

@register_flop_formula(aten.baddbmm)
def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for the baddbmm operation."""
    # Inputs should be a list of length 3.
    # Inputs contains the shapes of three tensors.
    return bmm_flop(a_shape, b_shape)

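# The matmul formulas above all follow the same convention: multiplying an
# (m, k) matrix by a (k, n) matrix is counted as 2 * m * k * n flops (one
# multiply plus one add per inner-product element; the NB comments note the
# "2 * k - 1" nuance). A quick illustration with hypothetical shapes, shown as
# an interpreter session rather than executed on import:
#
#   >>> mm_flop((64, 128), (128, 32))               # 2 * 64 * 128 * 32
#   524288
#   >>> addmm_flop((64, 32), (64, 128), (128, 32))  # bias add is not counted
#   524288
#   >>> bmm_flop((8, 64, 128), (8, 128, 32))        # batched: 8x the mm count
#   4194304
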

def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> int:
    """Count flops for convolution.

    Note that only the multiply-accumulate work of the convolution itself is
    counted; computation for bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = 2 * batch_size * prod(x_shape[2:]) * prod(w_shape).

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed

    Returns:
        int: the number of flops
    """

    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    c_out, c_in, *filter_size = w_shape

    """
    The general idea here is that for a regular conv, for each point in the output
    spatial dimension we convolve the filter with something (hence
    `prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by
    1. batch_size, 2. the cross product of input and weight channels.

    For the transpose, it's not each point in the *output* spatial dimension but
    each point in the *input* spatial dimension.
    """
    # NB(chilli): I don't think this properly accounts for padding :think:
    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
    flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
    return flop

@register_flop_formula([aten.convolution, aten._convolution])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
    """Count flops for convolution."""
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)

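# A small worked example of conv_flop_count with hypothetical shapes (a batch of
# one 1x8x8 input, four 3x3 filters, stride 1, no padding), shown as an
# interpreter session rather than executed on import:
#
#   >>> conv_flop_count(
#   ...     x_shape=[1, 1, 8, 8],    # (batch, c_in, h, w)
#   ...     w_shape=[4, 1, 3, 3],    # (c_out, c_in, kh, kw)
#   ...     out_shape=[1, 4, 6, 6],  # (batch, c_out, h_out, w_out)
#   ... )
#   2592
#
# i.e. 2 * batch * prod(out_shape[2:]) * prod(w_shape) = 2 * 1 * 36 * 36 = 2592.
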

@register_flop_formula(aten.convolution_backward)
def conv_backward_flop(
        grad_out_shape,
        x_shape,
        w_shape,
        _bias,
        _stride,
        _padding,
        _dilation,
        transposed,
        _output_padding,
        _groups,
        output_mask,
        out_shape) -> int:

    def t(shape):
        return [shape[1], shape[0]] + list(shape[2:])
    flop_count = 0

    """
    Let's say we have a regular 1D conv
    {A, B, C} [inp]
    {i, j} [weight]
    => (conv)
    {Ai + Bj, Bi + Cj} [out]

    And as a reminder, the transposed conv of the above is
    => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]

    For the backwards of conv, we now have
    {D, E} [grad_out]
    {A, B, C} [inp]
    {i, j} [weight]

    # grad_inp as conv_transpose(grad_out, weight)
    Let's first compute grad_inp. To do so, we can simply look at all the
    multiplications that each element of inp is involved in. For example, A is
    only involved in the first element of the output (and thus only depends upon
    D in grad_out), and C is only involved in the last element of the output
    (and thus only depends upon E in grad_out)

    {Di, Dj + Ei, Ej} [grad_inp]

    Note that this corresponds to the below conv_transpose. This gives us the
    output_mask[0] branch, which is grad_inp.

    {D, E} [inp (grad_out)]
    {i, j} [weight]
    => (conv_transpose)
    {Di, Dj + Ei, Ej} [out (grad_inp)]

    I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
    weight) as an exercise for the reader.

    # grad_weight as conv(inp, grad_out)
    To compute grad_weight, we again look at the terms in the output, which as
    a reminder is:
    => {Ai + Bj, Bi + Cj} [out]
    => {D, E} [grad_out]
    If we manually compute the gradient for the weights, we see it's
    {AD + BE, BD + CE} [grad_weight]

    This corresponds to the below conv
    {A, B, C} [inp]
    {D, E} [weight (grad_out)]
    => (conv)
    {AD + BE, BD + CE} [out (grad_weight)]

    # grad_weight of transposed conv as conv(grad_out, inp)
    As a reminder, the terms of the output of a transposed conv are:
    => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
    => {D, E, F, G} [grad_out]

    Manually computing the gradient for the weights, we see it's
    {AD + BE + CF, AE + BF + CG} [grad_weight]

    This corresponds to the below conv
    {D, E, F, G} [inp (grad_out)]
    {A, B, C} [weight (inp)]
    => (conv)
    {AD + BE + CF, AE + BF + CG} [out (grad_weight)]

    For the full backwards formula, there are also some details involving
    transpose of the batch/channel dimensions and groups, but I skip those for
    the sake of brevity (and they're pretty similar to matmul backwards)

    Check [conv backwards decomposition as conv forwards]
    """
    # grad_inp as conv_transpose(grad_out, weight)
    if output_mask[0]:
        grad_input_shape = get_shape(out_shape[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)

    if output_mask[1]:
        grad_weight_shape = get_shape(out_shape[1])
        if transposed:
            # grad_weight of transposed conv as conv(grad_out, inp)
            flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False)
        else:
            # grad_weight as conv(inp, grad_out)
            flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False)

    return flop_count

def sdpa_flop_count(query_shape, key_shape, value_shape):
    """
    Count flops for self-attention.

    NB: We can assume that value_shape == key_shape
    """
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3
    total_flops = 0
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
    return total_flops


@register_flop_formula([aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention])
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for self-attention."""
    # NB: We aren't accounting for causal attention here
    return sdpa_flop_count(query_shape, key_shape, value_shape)

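# A quick sanity check on sdpa_flop_count with hypothetical shapes (batch 1,
# 2 heads, sequence length 16, head dim 8), shown as an interpreter session
# rather than executed on import:
#
#   >>> sdpa_flop_count((1, 2, 16, 8), (1, 2, 16, 8), (1, 2, 16, 8))
#   16384
#
# i.e. 2 * b * h * s_q * s_k * (d_q + d_v) = 2 * 1 * 2 * 16 * 16 * 16. The
# softmax itself is not counted, only the two batched matmuls.
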

def _offsets_to_lengths(offsets, max_len):
    """
    Convert a tensor of cumulative sequence offsets into a list of per-batch lengths.

    If the offsets tensor is fake, then we don't know the actual lengths.
    In that case, we can just assume the worst case: each batch has max length.
    """
    from torch._subclasses.fake_tensor import FakeTensor
    from torch._subclasses.functional_tensor import FunctionalTensor
    if not isinstance(offsets, (FakeTensor, FunctionalTensor)):
        return offsets.diff().tolist()
    return [max_len] * (offsets.size(0) - 1)


def _unpack_flash_attention_nested_shapes(
    *,
    query,
    key,
    value,
    grad_out=None,
    cum_seq_q,
    cum_seq_k,
    max_q,
    max_k,
) -> Iterator[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Optional[Tuple[int, ...]]]]:
    """
    Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for
    NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
    each batch element.

    In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
    """
    if cum_seq_q is not None:
        # This means we should be dealing with a Nested Jagged Tensor query.
        # The inputs will have shape                  (sum(sequence len), heads, dimension)
        # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
        # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
        # So the flops calculation in this case is an overestimate of the actual flops.
        assert len(key.shape) == 3
        assert len(value.shape) == 3
        assert grad_out is None or grad_out.shape == query.shape
        _, h_q, d_q = query.shape
        _, h_k, d_k = key.shape
        _, h_v, d_v = value.shape
        assert cum_seq_q is not None
        assert cum_seq_k is not None
        assert cum_seq_q.shape == cum_seq_k.shape
        seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
        seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
        for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
            new_query_shape = (1, h_q, seq_q_len, d_q)
            new_key_shape = (1, h_k, seq_k_len, d_k)
            new_value_shape = (1, h_v, seq_k_len, d_v)
            new_grad_out_shape = new_query_shape if grad_out is not None else None
            yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
        return

    yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None

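# Illustration of the offsets -> lengths conversion above, with a made-up
# cumulative-sequence-length tensor (not executed on import):
#
#   >>> _offsets_to_lengths(torch.tensor([0, 3, 7, 12]), max_len=16)
#   [3, 4, 5]
#
# Under FakeTensor/FunctionalTensor tracing the real offsets are unknown, so the
# same call would return [16, 16, 16] instead, making the resulting flop count
# an upper bound.
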

def _unpack_efficient_attention_nested_shapes(
    *,
    query,
    key,
    value,
    grad_out=None,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
) -> Iterator[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Optional[Tuple[int, ...]]]]:
    """
    Given inputs to an efficient_attention_(forward|backward) kernel, this will handle behavior for
    NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
    each batch element.

    In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
    """
    if cu_seqlens_q is not None:
        # Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention.
        #
        # This means we should be dealing with a Nested Jagged Tensor query.
        # The inputs will have shape                  (1, sum(sequence len), heads, dimension)
        # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
        # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
        # So the flops calculation in this case is an overestimate of the actual flops.
        assert len(key.shape) == 4
        assert len(value.shape) == 4
        assert grad_out is None or grad_out.shape == query.shape
        _, _, h_q, d_q = query.shape
        _, _, h_k, d_k = key.shape
        _, _, h_v, d_v = value.shape
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert cu_seqlens_q.shape == cu_seqlens_k.shape
        seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
        seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
        for len_q, len_k in zip(seqlens_q, seqlens_k):
            new_query_shape = (1, h_q, len_q, d_q)
            new_key_shape = (1, h_k, len_k, d_k)
            new_value_shape = (1, h_v, len_k, d_v)
            new_grad_out_shape = new_query_shape if grad_out is not None else None
            yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
        return

    yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None


@register_flop_formula(aten._flash_attention_forward, get_raw=True)
def _flash_attention_forward_flop(
    query,
    key,
    value,
    cum_seq_q,
    cum_seq_k,
    max_q,
    max_k,
    *args,
    out_shape=None,
    **kwargs
) -> int:
    """Count flops for self-attention."""
    # NB: We aren't accounting for causal attention here
    # in case this is a nested tensor, we unpack the individual batch elements
    # and then sum the flops per batch element
    sizes = _unpack_flash_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
    )
    return sum(
        sdpa_flop_count(query_shape, key_shape, value_shape)
        for query_shape, key_shape, value_shape, _ in sizes
    )


@register_flop_formula(aten._efficient_attention_forward, get_raw=True)
def _efficient_attention_forward_flop(
    query,
    key,
    value,
    bias,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    *args,
    **kwargs
) -> int:
    """Count flops for self-attention."""
    # NB: We aren't accounting for causal attention here
    # in case this is a nested tensor, we unpack the individual batch elements
    # and then sum the flops per batch element
    sizes = _unpack_efficient_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
    )
    return sum(
        sdpa_flop_count(query_shape, key_shape, value_shape)
        for query_shape, key_shape, value_shape, _ in sizes
    )


def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    _b4, _h4, _s4, _d4 = grad_out_shape
    assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
    assert d_v == _d4 and s_k == _s3 and s_q == _s4
    total_flops = 0
    # Step 1: We recompute the scores matrix.
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))

    # Step 2: We propagate the gradients through the scores @ v operation.
    # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
    # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
    total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))

    # Step 3: We propagate the gradients through the q @ k operation.
    # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
    # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
    total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
    return total_flops


@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, aten._scaled_dot_product_flash_attention_backward])
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for self-attention backward."""
    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)

@register_flop_formula(aten._flash_attention_backward, get_raw=True)
def _flash_attention_backward_flop(
    grad_out,
    query,
    key,
    value,
    out,  # named _out_shape to avoid kwarg collision with out_shape created in wrapper
    logsumexp,
    cum_seq_q,
    cum_seq_k,
    max_q,
    max_k,
    *args,
    **kwargs,
) -> int:
    # in case this is a nested tensor, we unpack the individual batch elements
    # and then sum the flops per batch element
    shapes = _unpack_flash_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        grad_out=grad_out,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
    )
    return sum(
        sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
        for query_shape, key_shape, value_shape, grad_out_shape in shapes
    )


@register_flop_formula(aten._efficient_attention_backward, get_raw=True)
def _efficient_attention_backward_flop(
    grad_out,
    query,
    key,
    value,
    bias,
    out,  # named _out to avoid kwarg collision with out created in wrapper
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    *args,
    **kwargs,
) -> int:
    # in case this is a nested tensor, we unpack the individual batch elements
    # and then sum the flops per batch element
    shapes = _unpack_efficient_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        grad_out=grad_out,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
    )
    return sum(
        sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
        for query_shape, key_shape, value_shape, grad_out_shape in shapes
    )


flop_registry = {
    aten.mm: mm_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.baddbmm: baddbmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    aten._scaled_dot_product_efficient_attention: sdpa_flop,
    aten._scaled_dot_product_flash_attention: sdpa_flop,
    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
    aten._flash_attention_forward: _flash_attention_forward_flop,
    aten._efficient_attention_forward: _efficient_attention_forward_flop,
    aten._flash_attention_backward: _flash_attention_backward_flop,
    aten._efficient_attention_backward: _efficient_attention_backward_flop,
}

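# Example of registering a formula for a user-defined operator. The op name
# (torch.ops.mylib.fused_linear) and its signature are purely hypothetical; this
# is a sketch of the registration pattern, not something executed on import:
#
#   @register_flop_formula(torch.ops.mylib.fused_linear)
#   def fused_linear_flop(x_shape, w_shape, *args, out_shape=None, **kwargs) -> int:
#       # By default (get_raw=False) the formula receives shapes rather than
#       # tensors, plus the output shape via the `out_shape` kwarg.
#       batch, in_features = x_shape
#       out_features, in_features2 = w_shape
#       assert in_features == in_features2
#       return 2 * batch * in_features * out_features
#
# A formula can also be supplied per-instance through FlopCounterMode's
# `custom_mapping` argument, without touching the global registry.
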

def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


# Define the suffixes for different orders of magnitude
suffixes = ["", "K", "M", "B", "T"]
# Thanks BingChat!
def get_suffix_str(number):
    # Find the index of the appropriate suffix based on the number of digits
    # with some additional overflow.
    # i.e. 1.01B should be displayed as 1001M, not 1.001B
    index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3))
    return suffixes[index]

def convert_num_with_suffix(number, suffix):
    index = suffixes.index(suffix)
    # Divide the number by 1000^index and format it to three decimal places
    value = f"{number / 1000 ** index:.3f}"
    # Return the value and the suffix as a string
    return value + suffixes[index]

def convert_to_percent_str(num, denom):
    if denom == 0:
        return "0%"
    return f"{num / denom:.2%}"

def _pytreeify_preserve_structure(f):
    @wraps(f)
    def nf(args):
        flat_args, spec = tree_flatten(args)
        out = f(*flat_args)
        return tree_unflatten(out, spec)

    return nf

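# How the formatting helpers above combine (not executed on import). Because of
# the "overflow" behavior, displayed values stay roughly between 1 and 10000 of
# a unit, so e.g. 1234567 is rendered in K rather than M:
#
#   >>> get_suffix_str(1234567)
#   'K'
#   >>> convert_num_with_suffix(1234567, "K")
#   '1234.567K'
#   >>> convert_to_percent_str(1234567, 10_000_000)
#   '12.35%'
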

class FlopCounterMode(TorchDispatchMode):
    """
    ``FlopCounterMode`` is a context manager that counts the number of flops within its context.

    It does this using a ``TorchDispatchMode``.

    It also produces a hierarchical, per-module breakdown of the counts
    (tracked via ``ModuleTracker``); the ``mods`` constructor argument is
    deprecated and no longer needed for this.

    Example usage

    .. code-block:: python

        mod = ...
        inp = ...
        with FlopCounterMode() as flop_counter:
            mod(inp).sum().backward()

    """

    def __init__(
            self,
            mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
            depth: int = 2,
            display: bool = True,
            custom_mapping: Optional[Dict[Any, Any]] = None):
        super().__init__()
        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
        self.depth = depth
        self.display = display
        if custom_mapping is None:
            custom_mapping = {}
        if mods is not None:
            warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2)
        self.flop_registry = {
            **flop_registry,
            **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
        }
        self.mod_tracker = ModuleTracker()

    def get_total_flops(self) -> int:
        return sum(self.flop_counts['Global'].values())

    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
        """Return the flop counts as a dictionary of dictionaries.

        The outer dictionary is keyed by module name, and the inner dictionary
        is keyed by operation name.

        Returns:
            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
        """
        return {k: dict(v) for k, v in self.flop_counts.items()}

    def get_table(self, depth=None):
        if depth is None:
            depth = self.depth
        if depth is None:
            depth = 999999

        import tabulate
        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []
        global_flops = self.get_total_flops()
        global_suffix = get_suffix_str(global_flops)
        is_global_subsumed = False

        def process_mod(mod_name, depth):
            nonlocal is_global_subsumed

            total_flops = sum(self.flop_counts[mod_name].values())

            is_global_subsumed |= total_flops >= global_flops

            padding = " " * depth
            values = []
            values.append([
                padding + mod_name,
                convert_num_with_suffix(total_flops, global_suffix),
                convert_to_percent_str(total_flops, global_flops)
            ])
            for k, v in self.flop_counts[mod_name].items():
                values.append([
                    padding + " - " + str(k),
                    convert_num_with_suffix(v, global_suffix),
                    convert_to_percent_str(v, global_flops)
                ])
            return values

        for mod in sorted(self.flop_counts.keys()):
            if mod == 'Global':
                continue
            mod_depth = mod.count(".") + 1
            if mod_depth > depth:
                continue

            cur_values = process_mod(mod, mod_depth - 1)
            values.extend(cur_values)

        # We do a bit of messing around here to only output the "Global" value
        # if there are any FLOPs in there that aren't already fully contained by
        # a module.
        if 'Global' in self.flop_counts and not is_global_subsumed:
            for idx in range(len(values)):
                values[idx][0] = " " + values[idx][0]

            values = process_mod('Global', 0) + values

        if len(values) == 0:
            values = [["Global", "0", "0%"]]

        return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))

    def __enter__(self):
        self.flop_counts.clear()
        self.mod_tracker.__enter__()
        super().__enter__()
        return self

    def __exit__(self, *args):
        super().__exit__(*args)
        self.mod_tracker.__exit__()
        if self.display:
            print(self.get_table(self.depth))

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        return self._count_flops(func._overloadpacket, out, args, kwargs)

    def _count_flops(self, func_packet, out, args, kwargs):
        if func_packet in self.flop_registry:
            flop_count_func = self.flop_registry[func_packet]
            flop_count = flop_count_func(*args, **kwargs, out_val=out)  # type: ignore[operator]
            for par in set(self.mod_tracker.parents):
                self.flop_counts[par][func_packet] += flop_count

        return out

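# End-to-end sketch of how the pieces above fit together. The module and input
# shapes are arbitrary illustrations and the printed table is elided; this is a
# usage sketch, not something executed on import:
#
#   import torch
#   from torch.utils.flop_counter import FlopCounterMode
#
#   mod = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.Linear(64, 8))
#   inp = torch.randn(32, 128)
#   with FlopCounterMode(display=False) as flop_counter:
#       mod(inp).sum().backward()
#   total = flop_counter.get_total_flops()        # single int, summed over all ops
#   per_module = flop_counter.get_flop_counts()   # {module name: {op packet: flops}}
#   print(flop_counter.get_table(depth=2))        # hierarchical table (printed on exit
#                                                 # automatically when display=True)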