# Owner(s): ["module: unknown"]

import functools
import unittest

import torch
import torch.nn.functional as F
import torch.utils.flop_counter
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    skipIfRocm,
)

try:
    from torchvision import models as torchvision_models

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

HAS_CUDA = torch.cuda.is_available()


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    return str(sum(v for _, v in mode.flop_counts["Global"].items()))


def T(*shape, requires_grad=False):
    return torch.randn(*shape, requires_grad=requires_grad)


@unittest.skipIf(
    TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now"
)
class TestFlopCounter(TestCase):
    def test_flop_counter_variety(self):
        mod = torch.nn.Linear(9, 10)
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))

        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))

        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))

        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)

        # out_image_size = 2 * 5 * 5
        # kernel_size = 4 * 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)

        # out_image_size = 2 * 5
        # kernel_size = 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in

        # NB: I don't think this properly accounts for padding?
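        # (2 * 5) * 4 * 6 * 2 * 3 = 1440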
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_backward_reset(self):
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a.mm(a.t()).sum().backward()
            a.mm(a.t()).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """960""")

    def test_torchscript(self):
        def foo(x):
            return torch.mm(x, x)

        with FlopCounterMode() as mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with mode:
            ts_foo(T(5, 5))
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(
                    grad_output, grad_output
                )

        a = T(5, 5, requires_grad=True)
        with FlopCounterMode() as mode:
            a = _CustomOp.apply(a)
            a.sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """750""")

    def test_conv_backwards_as_decomposition(self):
        # [conv backwards decomposition as conv forwards]

        class onlyConvs(torch.autograd.Function):
            @staticmethod
            def forward(inp, weight, transposed):
                if not transposed:
                    return F.conv1d(inp, weight)
                else:
                    return F.conv_transpose1d(inp, weight)

            @staticmethod
            def setup_context(ctx, inputs, output):
                inp, weight, transposed = inputs
                ctx.save_for_backward(inp, weight)
                ctx.transposed = transposed

            @staticmethod
            def backward(ctx, grad_out):
                inp, weight = ctx.saved_tensors
                if not ctx.transposed:
                    grad_inp = F.conv_transpose1d(grad_out, weight)
                    grad_weight = F.conv1d(inp, grad_out)
                    return grad_inp, grad_weight, None
                else:
                    grad_inp = F.conv1d(grad_out, weight)
                    grad_weight = F.conv1d(
                        grad_out.transpose(1, 0), inp.transpose(1, 0)
                    )
                    return grad_inp, grad_weight.transpose(1, 0), None

        from torch.func import grad

        x = torch.randn(2, 3, 16, dtype=torch.float64)
        weight = torch.randn(3, 4, 4, dtype=torch.float64)

        def boring_conv(x, weight, transposed):
            if not transposed:
                return F.conv1d(x, weight).pow(2).sum()
            else:
                return F.conv_transpose1d(x, weight).pow(2).sum()

        def only_convs(x, weight, transposed):
            return onlyConvs.apply(x, weight, transposed).pow(2).sum()

        boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True)
        fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True)

        self.assertEqual(boring_grads, fun_grads)

    def test_convs(self):
        def assert_equivalence(f, expected_forward=None):
            with FlopCounterMode() as mode:
                f()
            conv_forward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution
            ]
            conv_backward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution_backward
            ]

            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
            if expected_forward is not None:
                self.assertEqual(conv_forward_flops, expected_forward)
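
        # assert_equivalence expects the backward convolution to be counted as
        # exactly twice the forward convolution (grad_input plus grad_weight),
        # optionally checking the absolute forward count as well.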
        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)

        for in_channels, out_channels, groups in [
            (1, 1, 1),
            (1, 3, 1),
            (3, 1, 1),
            (3, 7, 1),
            (2, 4, 2),
            (4, 2, 2),
        ]:
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
            transposed_weight = torch.randn(
                in_channels, out_channels, 2, 2, requires_grad=True
            )
            assert_equivalence(
                lambda: F.conv_transpose2d(x, transposed_weight).sum().backward()
            )

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        with FlopCounterMode(resnet18) as mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
        layer1_conv_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution
        ]
        layer1_conv_back_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution_backward
        ]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

        with FlopCounterMode() as mode:
            for i in range(50):
                out = model(x)
                out.sum().backward()
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        mode = FlopCounterMode(
            custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5}
        )
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

        def count(*args, out_val):
            return out_val.numel()

        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """20""")

    def test_noop(self):
        with FlopCounterMode() as mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa(self):
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
            backend,
            with_backward=False,
        ):
            query = torch.randn(
                batch_size,
                n_heads,
                seq_len_q,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            key = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            value = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim_v,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
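
            # Restrict SDPA to a single backend so the counted FLOPs can be
            # attributed to one implementation at a time.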
            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0, is_causal=True
                )
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_q,
            head_dim,
            head_dim,
            dtype,
        )

        flops = [
            run_uniform_flops(backend, with_backward=False)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [
            run_uniform_flops(backend, with_backward=True)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)

        run_nonuniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        )
        # Flash does not support non-uniform attention, i.e.
        # seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [
            run_nonuniform_flops(backend, with_backward=False)
            for backend in non_uniform_backends
        ]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [
            run_nonuniform_flops(backend, with_backward=True)
            for backend in non_uniform_backends
        ]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa_nested_tensor(self):
        def get_flops(q, k, v, backend, with_backward=False):
            mode = FlopCounterMode()

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            with backend, mode:
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=0, is_causal=True
                )
                if with_backward:
                    if out.is_nested:
                        out.values().sum().backward()
                    else:
                        out.sum().backward()

            return int(get_total_flops(mode))

        def get_nested_inputs(
            batch_size,
            n_heads,
            max_seq_len_q,
            max_seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        ):
            q_lengths = torch.tensor(
                [
                    max_seq_len_q // 4,
                    max_seq_len_q // 4 * 2,
                    max_seq_len_q // 4 * 3,
                    max_seq_len_q // 4 * 4,
                ]
            )
            k_lengths = torch.tensor(
                [
                    max_seq_len_k // 4,
                    max_seq_len_k // 4 * 2,
                    max_seq_len_k // 4 * 3,
                    max_seq_len_k // 4 * 4,
                ]
            )
            q_offsets, k_offsets = (
                torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0))).cuda()
                for lengths in (q_lengths, k_lengths)
            )
            q_values = torch.randn(
                q_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            k_values = torch.randn(
                k_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            v_values = torch.randn(
                k_offsets[-1],
                head_dim_v * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )

            q = torch.nested.nested_tensor_from_jagged(q_values, q_offsets)
            k = torch.nested.nested_tensor_from_jagged(k_values, k_offsets)
            v = torch.nested.nested_tensor_from_jagged(v_values, k_offsets)

            q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, n_heads, head_dim_v).transpose(1, 2)

            return q, k, v

        def get_dense_flops(q, k, v, backend, with_backward=False):
            def split_tensor(x):
                return (
                    y.unsqueeze(0).transpose(1, 2).detach().requires_grad_(True)
                    for y in x.transpose(1, 2).unbind(0)
                )

            q_tensors = split_tensor(q)
            k_tensors = split_tensor(k)
            v_tensors = split_tensor(v)

            flops = 0
            for q_i, k_i, v_i in zip(q_tensors, k_tensors, v_tensors):
                flops += get_flops(
                    q_i, k_i, v_i, backend=backend, with_backward=with_backward
                )

            return flops

        uniform_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 128,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        # max_seq_len_q != max_seq_len_k doesn't work for flash attention with dense tensors.
        differing_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 256,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_nested_attention_fake_tensors(self):
        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
        max_seqlen = 40
        with FakeTensorMode() as fake_mode:
            fake_x = fake_mode.from_tensor(x)
            fake_offsets = fake_mode.from_tensor(offsets)

            with FlopCounterMode() as fake_flop_counter_mode:
                torch.ops.aten._flash_attention_forward(
                    fake_x,
                    fake_x,
                    fake_x,
                    fake_offsets,
                    fake_offsets,
                    max_seqlen,
                    max_seqlen,
                    0.0,
                    False,
                    False,
                )

        dense_x = torch.randn(
            4, 40, 4, 16, dtype=torch.bfloat16, device="cuda"
        ).transpose(1, 2)

        with FlopCounterMode() as real_flop_counter_mode:
            torch.ops.aten._flash_attention_forward(
                dense_x,
                dense_x,
                dense_x,
                None,
                None,
                max_seqlen,
                max_seqlen,
                0.0,
                False,
                False,
            )

        self.assertEqual(
            int(get_total_flops(fake_flop_counter_mode)),
            int(get_total_flops(real_flop_counter_mode)),
        )
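
    # torch.mm with an explicit `out=` tensor should still be counted:
    # (10, 10) @ (10, 10) -> 10 * 10 * 2 * 10 = 2000 FLOPs.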
    def test_addmm_out(self):
        def f(x):
            y = torch.zeros(10, 10)
            return torch.mm(x, x, out=y)

        with FlopCounterMode() as mode:
            f(torch.randn(10, 10))

        self.assertExpectedInline(get_total_flops(mode), """2000""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        with FlopCounterMode() as mode:
            self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 1)
            self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 1)
            model(x).sum().backward()

        self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 0)
        self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 0)

    def test_pytrees(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        with FlopCounterMode() as mode:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
        self.assertExpectedInline(
            (mode.flop_counts["Mod"][torch.ops.aten.mm]), """12000"""
        )

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        with FlopCounterMode() as mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
        self.assertExpectedInline(
            (mode.flop_counts["Mod2"][torch.ops.aten.mm]), """6000"""
        )

    def test_warning(self):
        mod = torch.nn.Linear(2, 2)
        with self.assertWarnsRegex(UserWarning, "not needed"):
            FlopCounterMode(mod)

    def test_custom_op(self):
        from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        called = 0

        with self.assertRaisesRegex(
            ValueError, "expected each target to be OpOverloadPacket"
        ):
            register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)

        @register_flop_formula(torch.ops.mylib.foo)
        def formula(*args, **kwargs):
            nonlocal called
            called += 1
            return 9001

        x = torch.randn(3)
        with FlopCounterMode(display=False) as mode:
            y = foo(x)

        self.assertEqual(called, 1)
        self.assertExpectedInline(get_total_flops(mode), """9001""")


if __name__ == "__main__":
    run_tests()