# Owner(s): ["module: unknown"]

import functools
import unittest

import torch
import torch.nn.functional as F
import torch.utils.flop_counter
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    skipIfRocm,
)

try:
    from torchvision import models as torchvision_models

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

HAS_CUDA = torch.cuda.is_available()


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    return str(sum(v for _, v in mode.flop_counts["Global"].items()))


def T(*shape, requires_grad=False):
    return torch.randn(*shape, requires_grad=requires_grad)


@unittest.skipIf(
    TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now"
)
class TestFlopCounter(TestCase):
    def test_flop_counter_variety(self):
        mod = torch.nn.Linear(9, 10)
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))

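        # Rough per-op breakdown (hand-computed, not asserted individually):
        # mm 240 + addmm 240 + matmul 420 + einsum 672 + Linear(9, 10) on an
        # (8, 9) input 1440 = 3012.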
        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))

        # 3 addmm calls, each 4 * 6 * 2 * 5 = 240, so 720 in total
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))

        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)

        # out_image_size = 2 * 5 * 5
        # kernel_size = 4 * 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in
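        # i.e. (2 * 5 * 5) * (4 * 4) * 6 * 2 * 3 = 28800 (hand-computed)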

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)

        # out_image_size = 2 * 5
        # kernel_size = 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in
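        # i.e. (2 * 5) * 4 * 6 * 2 * 3 = 1440 (hand-computed)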

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()

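        # Hand-computed: forward = mm 240 + bmm 2352 = 2592; backward only needs
        # the grads w.r.t. the first operands (mm 240 + bmm 2352 = 2592), so the
        # total is 5184.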
        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_backward_reset(self):
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a.mm(a.t()).sum().backward()
            a.mm(a.t()).sum().backward()

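        # Hand-computed: each iteration is mm(4x5, 5x4) = 160 forward plus two
        # 160-FLOP mms in backward (both operands are views of `a`), i.e. 480;
        # two iterations give 960.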
        self.assertExpectedInline(get_total_flops(mode), """960""")

    def test_torchscript(self):
        def foo(x):
            return torch.mm(x, x)

        with FlopCounterMode() as mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with mode:
            ts_foo(T(5, 5))
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(
                    grad_output, grad_output
                )

        a = T(5, 5, requires_grad=True)
        with FlopCounterMode() as mode:
            a = _CustomOp.apply(a)
            a.sum().backward()

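        # Hand-computed: forward mm(5x5, 5x5) = 250, the custom backward does two
        # such mms = 500, for 750 total.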
        self.assertExpectedInline(get_total_flops(mode), """750""")

    def test_conv_backwards_as_decomposition(self):
        # [conv backwards decomposition as conv forwards]

        class onlyConvs(torch.autograd.Function):
            @staticmethod
            def forward(inp, weight, transposed):
                if not transposed:
                    return F.conv1d(inp, weight)
                else:
                    return F.conv_transpose1d(inp, weight)

            @staticmethod
            def setup_context(ctx, inputs, output):
                inp, weight, transposed = inputs
                ctx.save_for_backward(inp, weight)
                ctx.transposed = transposed

            @staticmethod
            def backward(ctx, grad_out):
                inp, weight = ctx.saved_tensors
                if not ctx.transposed:
                    grad_inp = F.conv_transpose1d(grad_out, weight)
                    grad_weight = F.conv1d(inp, grad_out)
                    return grad_inp, grad_weight, None
                else:
                    grad_inp = F.conv1d(grad_out, weight)
                    grad_weight = F.conv1d(
                        grad_out.transpose(1, 0), inp.transpose(1, 0)
                    )
                    return grad_inp, grad_weight.transpose(1, 0), None

        from torch.func import grad

        x = torch.randn(2, 3, 16, dtype=torch.float64)
        weight = torch.randn(3, 4, 4, dtype=torch.float64)

        def boring_conv(x, weight, transposed):
            if not transposed:
                return F.conv1d(x, weight).pow(2).sum()
            else:
                return F.conv_transpose1d(x, weight).pow(2).sum()

        def only_convs(x, weight, transposed):
            return onlyConvs.apply(x, weight, transposed).pow(2).sum()

        boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True)
        fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True)

        self.assertEqual(boring_grads, fun_grads)

    def test_convs(self):
        def assert_equivalence(f, expected_forward=None):
            with FlopCounterMode() as mode:
                f()
            conv_forward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution
            ]
            conv_backward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution_backward
            ]

            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
            if expected_forward is not None:
                self.assertEqual(conv_forward_flops, expected_forward)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)

        for in_channels, out_channels, groups in [
            (1, 1, 1),
            (1, 3, 1),
            (3, 1, 1),
            (3, 7, 1),
            (2, 4, 2),
            (4, 2, 2),
        ]:
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
            transposed_weight = torch.randn(
                in_channels, out_channels, 2, 2, requires_grad=True
            )
            assert_equivalence(
                lambda: F.conv_transpose2d(x, transposed_weight).sum().backward()
            )

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        with FlopCounterMode(resnet18) as mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
        layer1_conv_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution
        ]
        layer1_conv_back_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution_backward
        ]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

        with FlopCounterMode() as mode:
            for i in range(50):
                out = model(x)
                out.sum().backward()
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        mode = FlopCounterMode(
            custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5}
        )
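        # By default a custom formula receives shapes (note the `out_shape`
        # kwarg), not the actual tensors.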
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

        def count(*args, out_val):
            return out_val.numel()

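        # Setting `_get_raw` asks the counter to pass the raw output tensor
        # (`out_val`) to the formula instead of just shapes, so numel() works.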
        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """20""")

    def test_noop(self):
        with FlopCounterMode() as mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa(self):
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
            backend,
            with_backward=False,
        ):
            query = torch.randn(
                batch_size,
                n_heads,
                seq_len_q,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            key = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            value = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim_v,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0, is_causal=True
                )
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_q,
            head_dim,
            head_dim,
            dtype,
        )

        flops = [
            run_uniform_flops(backend, with_backward=False)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)

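        # Hand-computed: two batched matmuls (Q @ K^T and attn @ V), each
        # 2 * 4 * 8 * 128 * 128 * 64 = 67108864 FLOPs. The expected value matches
        # the full (unmasked) count, i.e. the causal mask is not discounted.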
        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [
            run_uniform_flops(backend, with_backward=True)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
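        # The math backward is counted as 2x the forward (3x total); flash and
        # mem-efficient are counted as 2.5x (3.5x total), presumably reflecting
        # the recomputation of attention in their backward passes.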
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)

        run_nonuniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        )
        # Flash does not support non-uniform attention, i.e. seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [
            run_nonuniform_flops(backend, with_backward=False)
            for backend in non_uniform_backends
        ]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)

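        # Hand-computed with seq_len_k = 256: each matmul is now
        # 2 * 4 * 8 * 128 * 256 * 64 = 134217728, so the forward total is 268435456.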
        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [
            run_nonuniform_flops(backend, with_backward=True)
            for backend in non_uniform_backends
        ]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa_nested_tensor(self):
        def get_flops(q, k, v, backend, with_backward=False):
            mode = FlopCounterMode()

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            with backend, mode:
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=0, is_causal=True
                )
                if with_backward:
                    if out.is_nested:
                        out.values().sum().backward()
                    else:
                        out.sum().backward()

            return int(get_total_flops(mode))

        def get_nested_inputs(
            batch_size,
            n_heads,
            max_seq_len_q,
            max_seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        ):
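            # Builds jagged (NestedTensor) q/k/v whose per-sample sequence lengths
            # are 1/4, 2/4, 3/4 and 4/4 of the max, so the FLOP count can be
            # compared against running each sample through dense SDPA.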
            q_lengths = torch.tensor(
                [
                    max_seq_len_q // 4,
                    max_seq_len_q // 4 * 2,
                    max_seq_len_q // 4 * 3,
                    max_seq_len_q // 4 * 4,
                ]
            )
            k_lengths = torch.tensor(
                [
                    max_seq_len_k // 4,
                    max_seq_len_k // 4 * 2,
                    max_seq_len_k // 4 * 3,
                    max_seq_len_k // 4 * 4,
                ]
            )
            q_offsets, k_offsets = (
                torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0))).cuda()
                for lengths in (q_lengths, k_lengths)
            )
            q_values = torch.randn(
                q_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            k_values = torch.randn(
                k_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            v_values = torch.randn(
                k_offsets[-1],
                head_dim_v * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )

            q = torch.nested.nested_tensor_from_jagged(q_values, q_offsets)
            k = torch.nested.nested_tensor_from_jagged(k_values, k_offsets)
            v = torch.nested.nested_tensor_from_jagged(v_values, k_offsets)

            q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, n_heads, head_dim_v).transpose(1, 2)

            return q, k, v

        def get_dense_flops(q, k, v, backend, with_backward=False):
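            # Sums FLOPs from running each batch element as a separate dense SDPA
            # call; this should match counting the single jagged call above.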
            def split_tensor(x):
                return (
                    y.unsqueeze(0).transpose(1, 2).detach().requires_grad_(True)
                    for y in x.transpose(1, 2).unbind(0)
                )

            q_tensors = split_tensor(q)
            k_tensors = split_tensor(k)
            v_tensors = split_tensor(v)

            flops = 0
            for q_i, k_i, v_i in zip(q_tensors, k_tensors, v_tensors):
                flops += get_flops(
                    q_i, k_i, v_i, backend=backend, with_backward=with_backward
                )

            return flops

        uniform_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 128,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        # max_seq_len_q != max_seq_len_k doesn't work for flash attention with dense tensors.
        differing_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 256,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_nested_attention_fake_tensors(self):
        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
        max_seqlen = 40
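        # FLOPs counted under FakeTensorMode for the varlen (jagged) flash
        # attention call should match FLOPs counted on equivalent dense, real
        # tensors below.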
        with FakeTensorMode() as fake_mode:
            fake_x = fake_mode.from_tensor(x)
            fake_offsets = fake_mode.from_tensor(offsets)

            with FlopCounterMode() as fake_flop_counter_mode:
                torch.ops.aten._flash_attention_forward(
                    fake_x,
                    fake_x,
                    fake_x,
                    fake_offsets,
                    fake_offsets,
                    max_seqlen,
                    max_seqlen,
                    0.0,
                    False,
                    False,
                )

        dense_x = torch.randn(
            4, 40, 4, 16, dtype=torch.bfloat16, device="cuda"
        ).transpose(1, 2)

        with FlopCounterMode() as real_flop_counter_mode:
            torch.ops.aten._flash_attention_forward(
                dense_x,
                dense_x,
                dense_x,
                None,
                None,
                max_seqlen,
                max_seqlen,
                0.0,
                False,
                False,
            )

        self.assertEqual(
            int(get_total_flops(fake_flop_counter_mode)),
            int(get_total_flops(real_flop_counter_mode)),
        )

    def test_addmm_out(self):
        def f(x):
            y = torch.zeros(10, 10)
            return torch.mm(x, x, out=y)

        with FlopCounterMode() as mode:
            f(torch.randn(10, 10))

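        # mm with out= is still counted: 10 * 10 * 2 * 10 = 2000 (hand-computed).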
        self.assertExpectedInline(get_total_flops(mode), """2000""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

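        # FlopCounterMode attributes FLOPs to modules via global forward
        # (pre-)hooks; they should be registered only while the mode is active.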
        with FlopCounterMode() as mode:
            self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 1)
            self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 1)
            model(x).sum().backward()

        self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 0)
        self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 0)

    def test_pytrees(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        with FlopCounterMode() as mode:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
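        # Hand-computed: each Foo does one 10x10 mm (2000 forward, 4000 backward
        # since both operands need grad), and there are two Foos -> 12000.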
        self.assertExpectedInline(
            (mode.flop_counts["Mod"][torch.ops.aten.mm]), """12000"""
        )

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        with FlopCounterMode() as mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
        self.assertExpectedInline(
            (mode.flop_counts["Mod2"][torch.ops.aten.mm]), """6000"""
        )

    def test_warning(self):
        mod = torch.nn.Linear(2, 2)
        with self.assertWarnsRegex(UserWarning, "not needed"):
            FlopCounterMode(mod)

    def test_custom_op(self):
        from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
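        # NB: this shadows the module-level wrapper (which forces display=False),
        # hence the explicit display=False when instantiating the mode below.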

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        called = 0

        with self.assertRaisesRegex(
            ValueError, "expected each target to be OpOverloadPacket"
        ):
            register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)

        @register_flop_formula(torch.ops.mylib.foo)
        def formula(*args, **kwargs):
            nonlocal called
            called += 1
            return 9001

        x = torch.randn(3)
        with FlopCounterMode(display=False) as mode:
            y = foo(x)

        self.assertEqual(called, 1)
        self.assertExpectedInline(get_total_flops(mode), """9001""")


if __name__ == "__main__":
    run_tests()