xref: /aosp_15_r20/external/pytorch/test/test_flop_counter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport functools
4*da0073e9SAndroid Build Coastguard Workerimport unittest
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
8*da0073e9SAndroid Build Coastguard Workerimport torch.utils.flop_counter
9*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensorMode
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import (
11*da0073e9SAndroid Build Coastguard Worker    PLATFORM_SUPPORTS_FLASH_ATTENTION,
12*da0073e9SAndroid Build Coastguard Worker    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
13*da0073e9SAndroid Build Coastguard Worker)
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
15*da0073e9SAndroid Build Coastguard Worker    run_tests,
16*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHDYNAMO,
17*da0073e9SAndroid Build Coastguard Worker    TestCase,
18*da0073e9SAndroid Build Coastguard Worker    skipIfRocm,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Worker
try:
    from torchvision import models as torchvision_models
except ImportError:
    # torchvision is optional; tests that need it are skipped via the
    # decorator below.
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorator for tests that require torchvision.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# Whether a CUDA device is usable; used to gate GPU-only tests.
HAS_CUDA = torch.cuda.is_available()
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Workerdef FlopCounterMode(*args, **kwargs):
33*da0073e9SAndroid Build Coastguard Worker    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
def get_total_flops(mode):
    """Return the total flops recorded under the ``"Global"`` scope of *mode*.

    The result is returned as a string so it can be compared directly
    against ``assertExpectedInline`` literals.
    """
    return str(sum(mode.flop_counts["Global"].values()))
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
def T(*shape, requires_grad=False):
    """Return a ``torch.randn`` tensor of the given *shape*.

    ``requires_grad`` is forwarded to ``torch.randn`` so the result can be
    used as an autograd leaf in the backward-pass tests.
    """
    options = {"requires_grad": requires_grad}
    return torch.randn(*shape, **options)
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
45*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now"
46*da0073e9SAndroid Build Coastguard Worker)
47*da0073e9SAndroid Build Coastguard Workerclass TestFlopCounter(TestCase):
48*da0073e9SAndroid Build Coastguard Worker    def test_flop_counter_variety(self):
49*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.Linear(9, 10)
50*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
51*da0073e9SAndroid Build Coastguard Worker            torch.mm(T(4, 5), T(5, 6))
52*da0073e9SAndroid Build Coastguard Worker            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
53*da0073e9SAndroid Build Coastguard Worker            torch.matmul(T(5, 6), T(6, 7))
54*da0073e9SAndroid Build Coastguard Worker            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
55*da0073e9SAndroid Build Coastguard Worker            mod(T(8, 9))
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """3012""")
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker    def test_op(self):
60*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
61*da0073e9SAndroid Build Coastguard Worker            torch.mm(T(4, 5), T(5, 6))
62*da0073e9SAndroid Build Coastguard Worker        # 4 * 6 * 2 * 5 = 240
63*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """240""")
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        with mode:
66*da0073e9SAndroid Build Coastguard Worker            torch.bmm(T(3, 4, 5), T(3, 5, 6))
67*da0073e9SAndroid Build Coastguard Worker        # 3 * 4 * 6 * 2 * 5 = 720
68*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """720""")
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker        with mode:
71*da0073e9SAndroid Build Coastguard Worker            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
72*da0073e9SAndroid Build Coastguard Worker            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
73*da0073e9SAndroid Build Coastguard Worker            torch.addmm(T(6), T(4, 5), T(5, 6))
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker        # 4 * 6 * 2 * 5 = 240
76*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """720""")
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        with mode:
79*da0073e9SAndroid Build Coastguard Worker            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker        # 3 * 4 * 6 * 2 * 5 = 720
82*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """720""")
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker        with mode:
85*da0073e9SAndroid Build Coastguard Worker            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker        # out_image_size = 2 * 5 * 5
88*da0073e9SAndroid Build Coastguard Worker        # kernel_size = 4 * 4
89*da0073e9SAndroid Build Coastguard Worker        # c_out = 6
90*da0073e9SAndroid Build Coastguard Worker        # c_in = 3
91*da0073e9SAndroid Build Coastguard Worker        # out_image_size * kernel_size * c_out * 2 * c_in
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker        # NB: I don't think this properly accounts for padding?
94*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """28800""")
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker        with mode:
97*da0073e9SAndroid Build Coastguard Worker            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker        # out_image_size = 2 * 5
100*da0073e9SAndroid Build Coastguard Worker        # kernel_size = 4
101*da0073e9SAndroid Build Coastguard Worker        # c_out = 6
102*da0073e9SAndroid Build Coastguard Worker        # c_in = 3
103*da0073e9SAndroid Build Coastguard Worker        # out_image_size * kernel_size * c_out * 2 * c_in
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker        # NB: I don't think this properly accounts for padding?
106*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """1440""")
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker    def test_backward(self):
109*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
110*da0073e9SAndroid Build Coastguard Worker            a = T(4, 5, requires_grad=True)
111*da0073e9SAndroid Build Coastguard Worker            a = torch.mm(a, T(5, 6))
112*da0073e9SAndroid Build Coastguard Worker            a = a.unsqueeze(0).expand(7, 4, 6)
113*da0073e9SAndroid Build Coastguard Worker            a = torch.bmm(a, T(7, 6, 7))
114*da0073e9SAndroid Build Coastguard Worker            a.sum().backward()
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """5184""")
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker    def test_backward_reset(self):
119*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
120*da0073e9SAndroid Build Coastguard Worker            a = T(4, 5, requires_grad=True)
121*da0073e9SAndroid Build Coastguard Worker            a.mm(a.t()).sum().backward()
122*da0073e9SAndroid Build Coastguard Worker            a.mm(a.t()).sum().backward()
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """960""")
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker    def test_torchscript(self):
127*da0073e9SAndroid Build Coastguard Worker        def foo(x):
128*da0073e9SAndroid Build Coastguard Worker            return torch.mm(x, x)
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
131*da0073e9SAndroid Build Coastguard Worker            foo(T(5, 5))
132*da0073e9SAndroid Build Coastguard Worker        unscripted_flops = get_total_flops(mode)
133*da0073e9SAndroid Build Coastguard Worker        ts_foo = torch.jit.script(foo)
134*da0073e9SAndroid Build Coastguard Worker        with mode:
135*da0073e9SAndroid Build Coastguard Worker            ts_foo(T(5, 5))
136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unscripted_flops, get_total_flops(mode))
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker    def test_autograd_op(self):
139*da0073e9SAndroid Build Coastguard Worker        class _CustomOp(torch.autograd.Function):
140*da0073e9SAndroid Build Coastguard Worker            @staticmethod
141*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
142*da0073e9SAndroid Build Coastguard Worker                return torch.mm(input, input)
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker            @staticmethod
145*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
146*da0073e9SAndroid Build Coastguard Worker                return torch.mm(grad_output, grad_output) + torch.mm(
147*da0073e9SAndroid Build Coastguard Worker                    grad_output, grad_output
148*da0073e9SAndroid Build Coastguard Worker                )
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker        a = T(5, 5, requires_grad=True)
151*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
152*da0073e9SAndroid Build Coastguard Worker            a = _CustomOp.apply(a)
153*da0073e9SAndroid Build Coastguard Worker            a.sum().backward()
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """750""")
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker    def test_conv_backwards_as_decomposition(self):
158*da0073e9SAndroid Build Coastguard Worker        # [conv backwards decomposition as conv forwards]
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker        class onlyConvs(torch.autograd.Function):
161*da0073e9SAndroid Build Coastguard Worker            @staticmethod
162*da0073e9SAndroid Build Coastguard Worker            def forward(inp, weight, transposed):
163*da0073e9SAndroid Build Coastguard Worker                if not transposed:
164*da0073e9SAndroid Build Coastguard Worker                    return F.conv1d(inp, weight)
165*da0073e9SAndroid Build Coastguard Worker                else:
166*da0073e9SAndroid Build Coastguard Worker                    return F.conv_transpose1d(inp, weight)
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker            @staticmethod
169*da0073e9SAndroid Build Coastguard Worker            def setup_context(ctx, inputs, output):
170*da0073e9SAndroid Build Coastguard Worker                inp, weight, transposed = inputs
171*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(inp, weight)
172*da0073e9SAndroid Build Coastguard Worker                ctx.transposed = transposed
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker            @staticmethod
175*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_out):
176*da0073e9SAndroid Build Coastguard Worker                inp, weight = ctx.saved_tensors
177*da0073e9SAndroid Build Coastguard Worker                if not ctx.transposed:
178*da0073e9SAndroid Build Coastguard Worker                    grad_inp = F.conv_transpose1d(grad_out, weight)
179*da0073e9SAndroid Build Coastguard Worker                    grad_weight = F.conv1d(inp, grad_out)
180*da0073e9SAndroid Build Coastguard Worker                    return grad_inp, grad_weight, None
181*da0073e9SAndroid Build Coastguard Worker                else:
182*da0073e9SAndroid Build Coastguard Worker                    grad_inp = F.conv1d(grad_out, weight)
183*da0073e9SAndroid Build Coastguard Worker                    grad_weight = F.conv1d(
184*da0073e9SAndroid Build Coastguard Worker                        grad_out.transpose(1, 0), inp.transpose(1, 0)
185*da0073e9SAndroid Build Coastguard Worker                    )
186*da0073e9SAndroid Build Coastguard Worker                    return grad_inp, grad_weight.transpose(1, 0), None
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker        from torch.func import grad
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 16, dtype=torch.float64)
191*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(3, 4, 4, dtype=torch.float64)
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker        def boring_conv(x, weight, transposed):
194*da0073e9SAndroid Build Coastguard Worker            if not transposed:
195*da0073e9SAndroid Build Coastguard Worker                return F.conv1d(x, weight).pow(2).sum()
196*da0073e9SAndroid Build Coastguard Worker            else:
197*da0073e9SAndroid Build Coastguard Worker                return F.conv_transpose1d(x, weight).pow(2).sum()
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker        def only_convs(x, weight, transposed):
200*da0073e9SAndroid Build Coastguard Worker            return onlyConvs.apply(x, weight, transposed).pow(2).sum()
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker        boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True)
203*da0073e9SAndroid Build Coastguard Worker        fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True)
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(boring_grads, fun_grads)
206*da0073e9SAndroid Build Coastguard Worker
207*da0073e9SAndroid Build Coastguard Worker    def test_convs(self):
208*da0073e9SAndroid Build Coastguard Worker        def assert_equivalence(f, expected_forward=None):
209*da0073e9SAndroid Build Coastguard Worker            with FlopCounterMode() as mode:
210*da0073e9SAndroid Build Coastguard Worker                f()
211*da0073e9SAndroid Build Coastguard Worker            conv_forward_flops = mode.get_flop_counts()["Global"][
212*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.convolution
213*da0073e9SAndroid Build Coastguard Worker            ]
214*da0073e9SAndroid Build Coastguard Worker            conv_backward_flops = mode.get_flop_counts()["Global"][
215*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.convolution_backward
216*da0073e9SAndroid Build Coastguard Worker            ]
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
219*da0073e9SAndroid Build Coastguard Worker            if expected_forward is not None:
220*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(conv_forward_flops, expected_forward)
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 1, 2, 2, requires_grad=True)
223*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
224*da0073e9SAndroid Build Coastguard Worker        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 1, 2, 2, requires_grad=True)
227*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
228*da0073e9SAndroid Build Coastguard Worker        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        for in_channels, out_channels, groups in [
231*da0073e9SAndroid Build Coastguard Worker            (1, 1, 1),
232*da0073e9SAndroid Build Coastguard Worker            (1, 3, 1),
233*da0073e9SAndroid Build Coastguard Worker            (3, 1, 1),
234*da0073e9SAndroid Build Coastguard Worker            (3, 7, 1),
235*da0073e9SAndroid Build Coastguard Worker            (2, 4, 2),
236*da0073e9SAndroid Build Coastguard Worker            (4, 2, 2),
237*da0073e9SAndroid Build Coastguard Worker        ]:
238*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
239*da0073e9SAndroid Build Coastguard Worker            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
240*da0073e9SAndroid Build Coastguard Worker            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
241*da0073e9SAndroid Build Coastguard Worker            transposed_weight = torch.randn(
242*da0073e9SAndroid Build Coastguard Worker                in_channels, out_channels, 2, 2, requires_grad=True
243*da0073e9SAndroid Build Coastguard Worker            )
244*da0073e9SAndroid Build Coastguard Worker            assert_equivalence(
245*da0073e9SAndroid Build Coastguard Worker                lambda: F.conv_transpose2d(x, transposed_weight).sum().backward()
246*da0073e9SAndroid Build Coastguard Worker            )
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker    @skipIfNoTorchVision
249*da0073e9SAndroid Build Coastguard Worker    def test_module(self):
250*da0073e9SAndroid Build Coastguard Worker        resnet18 = torchvision_models.resnet18()
251*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode(resnet18) as mode:
252*da0073e9SAndroid Build Coastguard Worker            a = T(1, 3, 224, 224, requires_grad=True)
253*da0073e9SAndroid Build Coastguard Worker            resnet18(a).sum().backward()
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
256*da0073e9SAndroid Build Coastguard Worker        layer1_conv_flops = mode.flop_counts["ResNet.layer1"][
257*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.convolution
258*da0073e9SAndroid Build Coastguard Worker        ]
259*da0073e9SAndroid Build Coastguard Worker        layer1_conv_back_flops = mode.flop_counts["ResNet.layer1"][
260*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.convolution_backward
261*da0073e9SAndroid Build Coastguard Worker        ]
262*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
263*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker    def test_conv_transpose_loop(self):
266*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 4, 30, 2)
267*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
270*da0073e9SAndroid Build Coastguard Worker            for i in range(50):
271*da0073e9SAndroid Build Coastguard Worker                out = model(x)
272*da0073e9SAndroid Build Coastguard Worker                out.sum().backward()
273*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker    def test_custom(self):
276*da0073e9SAndroid Build Coastguard Worker        mode = FlopCounterMode(
277*da0073e9SAndroid Build Coastguard Worker            custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5}
278*da0073e9SAndroid Build Coastguard Worker        )
279*da0073e9SAndroid Build Coastguard Worker        with mode:
280*da0073e9SAndroid Build Coastguard Worker            a = T(4, 5)
281*da0073e9SAndroid Build Coastguard Worker            a + a
282*da0073e9SAndroid Build Coastguard Worker
283*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """5""")
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker        def count(*args, out_val):
286*da0073e9SAndroid Build Coastguard Worker            return out_val.numel()
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker        count._get_raw = True
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
291*da0073e9SAndroid Build Coastguard Worker        with mode:
292*da0073e9SAndroid Build Coastguard Worker            a = T(4, 5)
293*da0073e9SAndroid Build Coastguard Worker            a + a
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """20""")
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker    def test_noop(self):
298*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
299*da0073e9SAndroid Build Coastguard Worker            T(4, 5).cos()
300*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa(self):
        # Flop counts for scaled_dot_product_attention must not depend on which
        # backend (math, flash, memory-efficient) runs the forward pass. The
        # backward totals differ by backend and are pinned below.
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
            backend,
            with_backward=False,
        ):
            # Runs causal SDPA on fresh random CUDA inputs, restricted to the
            # named backend ("math", "flash" or "mem_efficient"). Returns the
            # total counted flops: forward, plus backward if with_backward.
            query = torch.randn(
                batch_size,
                n_heads,
                seq_len_q,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            key = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            value = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim_v,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )

            # NOTE: ``backend`` is rebound here from the backend name to the
            # sdp_kernel context manager that enables only that backend.
            # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in
            # recent PyTorch releases in favour of torch.nn.attention.sdpa_kernel;
            # consider migrating once the minimum supported version allows it.
            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0, is_causal=True
                )
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_q,
            head_dim,
            head_dim,
            dtype,
        )

        flops = [
            run_uniform_flops(backend, with_backward=False)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        # 2 matmuls (Q @ K^T and attn @ V) * 2 * 4 * 8 * 128 * 128 * 64
        # = 134217728; the causal mask is not discounted.
        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [
            run_uniform_flops(backend, with_backward=True)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
        # Forward + backward totals, as multiples of the forward count:
        #   math: 3x (backward = 2x forward).
        #   flash / mem-efficient: 7/2x (backward = 2.5x forward). Presumably
        #   the extra half reflects recomputing attention in the fused
        #   backward; TODO confirm against the SDPA flop formulas.
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)

        run_nonuniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        )
        # Flash does not support non-uniform attention, i.e. seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [
            run_nonuniform_flops(backend, with_backward=False)
            for backend in non_uniform_backends
        ]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        # 2 matmuls * 2 * 4 * 8 * 128 * 256 * 64 = 268435456
        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [
            run_nonuniform_flops(backend, with_backward=True)
            for backend in non_uniform_backends
        ]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        # Same ratios as the uniform case: 3x and 7/2x of the forward count.
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")
437*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Nested tensor
440*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
441*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
442*da0073e9SAndroid Build Coastguard Worker        not PLATFORM_SUPPORTS_FLASH_ATTENTION
443*da0073e9SAndroid Build Coastguard Worker        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
444*da0073e9SAndroid Build Coastguard Worker        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
445*da0073e9SAndroid Build Coastguard Worker    )
446*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_nested_tensor(self):
447*da0073e9SAndroid Build Coastguard Worker        def get_flops(q, k, v, backend, with_backward=False):
448*da0073e9SAndroid Build Coastguard Worker            mode = FlopCounterMode()
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker            if backend == "math":
451*da0073e9SAndroid Build Coastguard Worker                backend = torch.backends.cuda.sdp_kernel(
452*da0073e9SAndroid Build Coastguard Worker                    enable_flash=False, enable_math=True, enable_mem_efficient=False
453*da0073e9SAndroid Build Coastguard Worker                )
454*da0073e9SAndroid Build Coastguard Worker            elif backend == "flash":
455*da0073e9SAndroid Build Coastguard Worker                backend = torch.backends.cuda.sdp_kernel(
456*da0073e9SAndroid Build Coastguard Worker                    enable_flash=True, enable_math=False, enable_mem_efficient=False
457*da0073e9SAndroid Build Coastguard Worker                )
458*da0073e9SAndroid Build Coastguard Worker            elif backend == "mem_efficient":
459*da0073e9SAndroid Build Coastguard Worker                backend = torch.backends.cuda.sdp_kernel(
460*da0073e9SAndroid Build Coastguard Worker                    enable_flash=False, enable_math=False, enable_mem_efficient=True
461*da0073e9SAndroid Build Coastguard Worker                )
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker            with backend, mode:
464*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(
465*da0073e9SAndroid Build Coastguard Worker                    q, k, v, dropout_p=0, is_causal=True
466*da0073e9SAndroid Build Coastguard Worker                )
467*da0073e9SAndroid Build Coastguard Worker                if with_backward:
468*da0073e9SAndroid Build Coastguard Worker                    if out.is_nested:
469*da0073e9SAndroid Build Coastguard Worker                        out.values().sum().backward()
470*da0073e9SAndroid Build Coastguard Worker                    else:
471*da0073e9SAndroid Build Coastguard Worker                        out.sum().backward()
472*da0073e9SAndroid Build Coastguard Worker
473*da0073e9SAndroid Build Coastguard Worker            return int(get_total_flops(mode))
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker        def get_nested_inputs(
476*da0073e9SAndroid Build Coastguard Worker            batch_size,
477*da0073e9SAndroid Build Coastguard Worker            n_heads,
478*da0073e9SAndroid Build Coastguard Worker            max_seq_len_q,
479*da0073e9SAndroid Build Coastguard Worker            max_seq_len_k,
480*da0073e9SAndroid Build Coastguard Worker            head_dim,
481*da0073e9SAndroid Build Coastguard Worker            head_dim_v,
482*da0073e9SAndroid Build Coastguard Worker            dtype,
483*da0073e9SAndroid Build Coastguard Worker        ):
484*da0073e9SAndroid Build Coastguard Worker            q_lengths = torch.tensor(
485*da0073e9SAndroid Build Coastguard Worker                [
486*da0073e9SAndroid Build Coastguard Worker                    max_seq_len_q // 4,
487*da0073e9SAndroid Build Coastguard Worker                    max_seq_len_q // 4 * 2,
488*da0073e9SAndroid Build Coastguard Worker                    max_seq_len_q // 4 * 3,
489*da0073e9SAndroid Build Coastguard Worker                    max_seq_len_q // 4 * 4,
490*da0073e9SAndroid Build Coastguard Worker                ]
491*da0073e9SAndroid Build Coastguard Worker            )
492*da0073e9SAndroid Build Coastguard Worker            k_lengths = torch.tensor(
493*da0073e9SAndroid Build Coastguard Worker                [
494*da0073e9SAndroid Build Coastguard Worker                    max_seq_len_k // 4,
495*da0073e9SAndroid Build Coastguard Worker                    max_seq_len_k // 4 * 2,
496*da0073e9SAndroid Build Coastguard Worker                    max_seq_len_k // 4 * 3,
497*da0073e9SAndroid Build Coastguard Worker                    max_seq_len_k // 4 * 4,
498*da0073e9SAndroid Build Coastguard Worker                ]
499*da0073e9SAndroid Build Coastguard Worker            )
500*da0073e9SAndroid Build Coastguard Worker            q_offsets, k_offsets = (
501*da0073e9SAndroid Build Coastguard Worker                torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0))).cuda()
502*da0073e9SAndroid Build Coastguard Worker                for lengths in (q_lengths, k_lengths)
503*da0073e9SAndroid Build Coastguard Worker            )
504*da0073e9SAndroid Build Coastguard Worker            q_values = torch.randn(
505*da0073e9SAndroid Build Coastguard Worker                q_offsets[-1],
506*da0073e9SAndroid Build Coastguard Worker                head_dim * n_heads,
507*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
508*da0073e9SAndroid Build Coastguard Worker                requires_grad=True,
509*da0073e9SAndroid Build Coastguard Worker                device="cuda",
510*da0073e9SAndroid Build Coastguard Worker            )
511*da0073e9SAndroid Build Coastguard Worker            k_values = torch.randn(
512*da0073e9SAndroid Build Coastguard Worker                k_offsets[-1],
513*da0073e9SAndroid Build Coastguard Worker                head_dim * n_heads,
514*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
515*da0073e9SAndroid Build Coastguard Worker                requires_grad=True,
516*da0073e9SAndroid Build Coastguard Worker                device="cuda",
517*da0073e9SAndroid Build Coastguard Worker            )
518*da0073e9SAndroid Build Coastguard Worker            v_values = torch.randn(
519*da0073e9SAndroid Build Coastguard Worker                k_offsets[-1],
520*da0073e9SAndroid Build Coastguard Worker                head_dim_v * n_heads,
521*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
522*da0073e9SAndroid Build Coastguard Worker                requires_grad=True,
523*da0073e9SAndroid Build Coastguard Worker                device="cuda",
524*da0073e9SAndroid Build Coastguard Worker            )
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker            q = torch.nested.nested_tensor_from_jagged(q_values, q_offsets)
527*da0073e9SAndroid Build Coastguard Worker            k = torch.nested.nested_tensor_from_jagged(k_values, k_offsets)
528*da0073e9SAndroid Build Coastguard Worker            v = torch.nested.nested_tensor_from_jagged(v_values, k_offsets)
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker            q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
531*da0073e9SAndroid Build Coastguard Worker            k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
532*da0073e9SAndroid Build Coastguard Worker            v = v.view(batch_size, -1, n_heads, head_dim_v).transpose(1, 2)
533*da0073e9SAndroid Build Coastguard Worker
534*da0073e9SAndroid Build Coastguard Worker            return q, k, v
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker        def get_dense_flops(q, k, v, backend, with_backward=False):
537*da0073e9SAndroid Build Coastguard Worker            def split_tensor(x):
538*da0073e9SAndroid Build Coastguard Worker                return (
539*da0073e9SAndroid Build Coastguard Worker                    y.unsqueeze(0).transpose(1, 2).detach().requires_grad_(True)
540*da0073e9SAndroid Build Coastguard Worker                    for y in x.transpose(1, 2).unbind(0)
541*da0073e9SAndroid Build Coastguard Worker                )
542*da0073e9SAndroid Build Coastguard Worker
543*da0073e9SAndroid Build Coastguard Worker            q_tensors = split_tensor(q)
544*da0073e9SAndroid Build Coastguard Worker            k_tensors = split_tensor(k)
545*da0073e9SAndroid Build Coastguard Worker            v_tensors = split_tensor(v)
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker            flops = 0
548*da0073e9SAndroid Build Coastguard Worker            for q_i, k_i, v_i in zip(q_tensors, k_tensors, v_tensors):
549*da0073e9SAndroid Build Coastguard Worker                flops += get_flops(
550*da0073e9SAndroid Build Coastguard Worker                    q_i, k_i, v_i, backend=backend, with_backward=with_backward
551*da0073e9SAndroid Build Coastguard Worker                )
552*da0073e9SAndroid Build Coastguard Worker
553*da0073e9SAndroid Build Coastguard Worker            return flops
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker        uniform_config = {
556*da0073e9SAndroid Build Coastguard Worker            "batch_size": 4,
557*da0073e9SAndroid Build Coastguard Worker            "n_heads": 8,
558*da0073e9SAndroid Build Coastguard Worker            "max_seq_len_q": 128,
559*da0073e9SAndroid Build Coastguard Worker            "max_seq_len_k": 128,
560*da0073e9SAndroid Build Coastguard Worker            "head_dim": 64,
561*da0073e9SAndroid Build Coastguard Worker            "head_dim_v": 64,
562*da0073e9SAndroid Build Coastguard Worker            "dtype": torch.float16,
563*da0073e9SAndroid Build Coastguard Worker        }
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker        # max_seq_len_q != max_seq_len_k doesn't work for flash attention with dense tensors.
566*da0073e9SAndroid Build Coastguard Worker        differing_config = {
567*da0073e9SAndroid Build Coastguard Worker            "batch_size": 4,
568*da0073e9SAndroid Build Coastguard Worker            "n_heads": 8,
569*da0073e9SAndroid Build Coastguard Worker            "max_seq_len_q": 128,
570*da0073e9SAndroid Build Coastguard Worker            "max_seq_len_k": 256,
571*da0073e9SAndroid Build Coastguard Worker            "head_dim": 64,
572*da0073e9SAndroid Build Coastguard Worker            "head_dim_v": 64,
573*da0073e9SAndroid Build Coastguard Worker            "dtype": torch.float16,
574*da0073e9SAndroid Build Coastguard Worker        }
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
577*da0073e9SAndroid Build Coastguard Worker            get_dense_flops(
578*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**uniform_config),
579*da0073e9SAndroid Build Coastguard Worker                backend="flash",
580*da0073e9SAndroid Build Coastguard Worker                with_backward=False,
581*da0073e9SAndroid Build Coastguard Worker            ),
582*da0073e9SAndroid Build Coastguard Worker            get_flops(
583*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**uniform_config),
584*da0073e9SAndroid Build Coastguard Worker                backend="flash",
585*da0073e9SAndroid Build Coastguard Worker                with_backward=False,
586*da0073e9SAndroid Build Coastguard Worker            ),
587*da0073e9SAndroid Build Coastguard Worker        )
588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
589*da0073e9SAndroid Build Coastguard Worker            get_dense_flops(
590*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**uniform_config),
591*da0073e9SAndroid Build Coastguard Worker                backend="mem_efficient",
592*da0073e9SAndroid Build Coastguard Worker                with_backward=False,
593*da0073e9SAndroid Build Coastguard Worker            ),
594*da0073e9SAndroid Build Coastguard Worker            get_flops(
595*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**uniform_config),
596*da0073e9SAndroid Build Coastguard Worker                backend="mem_efficient",
597*da0073e9SAndroid Build Coastguard Worker                with_backward=False,
598*da0073e9SAndroid Build Coastguard Worker            ),
599*da0073e9SAndroid Build Coastguard Worker        )
600*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
601*da0073e9SAndroid Build Coastguard Worker            get_dense_flops(
602*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**differing_config),
603*da0073e9SAndroid Build Coastguard Worker                backend="mem_efficient",
604*da0073e9SAndroid Build Coastguard Worker                with_backward=False,
605*da0073e9SAndroid Build Coastguard Worker            ),
606*da0073e9SAndroid Build Coastguard Worker            get_flops(
607*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**differing_config),
608*da0073e9SAndroid Build Coastguard Worker                backend="mem_efficient",
609*da0073e9SAndroid Build Coastguard Worker                with_backward=False,
610*da0073e9SAndroid Build Coastguard Worker            ),
611*da0073e9SAndroid Build Coastguard Worker        )
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
614*da0073e9SAndroid Build Coastguard Worker            get_dense_flops(
615*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**uniform_config),
616*da0073e9SAndroid Build Coastguard Worker                backend="flash",
617*da0073e9SAndroid Build Coastguard Worker                with_backward=True,
618*da0073e9SAndroid Build Coastguard Worker            ),
619*da0073e9SAndroid Build Coastguard Worker            get_flops(
620*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**uniform_config),
621*da0073e9SAndroid Build Coastguard Worker                backend="flash",
622*da0073e9SAndroid Build Coastguard Worker                with_backward=True,
623*da0073e9SAndroid Build Coastguard Worker            ),
624*da0073e9SAndroid Build Coastguard Worker        )
625*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
626*da0073e9SAndroid Build Coastguard Worker            get_dense_flops(
627*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**uniform_config),
628*da0073e9SAndroid Build Coastguard Worker                backend="mem_efficient",
629*da0073e9SAndroid Build Coastguard Worker                with_backward=True,
630*da0073e9SAndroid Build Coastguard Worker            ),
631*da0073e9SAndroid Build Coastguard Worker            get_flops(
632*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**uniform_config),
633*da0073e9SAndroid Build Coastguard Worker                backend="mem_efficient",
634*da0073e9SAndroid Build Coastguard Worker                with_backward=True,
635*da0073e9SAndroid Build Coastguard Worker            ),
636*da0073e9SAndroid Build Coastguard Worker        )
637*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
638*da0073e9SAndroid Build Coastguard Worker            get_dense_flops(
639*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**differing_config),
640*da0073e9SAndroid Build Coastguard Worker                backend="mem_efficient",
641*da0073e9SAndroid Build Coastguard Worker                with_backward=True,
642*da0073e9SAndroid Build Coastguard Worker            ),
643*da0073e9SAndroid Build Coastguard Worker            get_flops(
644*da0073e9SAndroid Build Coastguard Worker                *get_nested_inputs(**differing_config),
645*da0073e9SAndroid Build Coastguard Worker                backend="mem_efficient",
646*da0073e9SAndroid Build Coastguard Worker                with_backward=True,
647*da0073e9SAndroid Build Coastguard Worker            ),
648*da0073e9SAndroid Build Coastguard Worker        )
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Nested tensor
651*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
652*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
653*da0073e9SAndroid Build Coastguard Worker        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
654*da0073e9SAndroid Build Coastguard Worker        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
655*da0073e9SAndroid Build Coastguard Worker    )
656*da0073e9SAndroid Build Coastguard Worker    def test_nested_attention_fake_tensors(self):
657*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
658*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
659*da0073e9SAndroid Build Coastguard Worker        max_seqlen = 40
660*da0073e9SAndroid Build Coastguard Worker        with FakeTensorMode() as fake_mode:
661*da0073e9SAndroid Build Coastguard Worker            fake_x = fake_mode.from_tensor(x)
662*da0073e9SAndroid Build Coastguard Worker            fake_offsets = fake_mode.from_tensor(offsets)
663*da0073e9SAndroid Build Coastguard Worker
664*da0073e9SAndroid Build Coastguard Worker            with FlopCounterMode() as fake_flop_counter_mode:
665*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._flash_attention_forward(
666*da0073e9SAndroid Build Coastguard Worker                    fake_x,
667*da0073e9SAndroid Build Coastguard Worker                    fake_x,
668*da0073e9SAndroid Build Coastguard Worker                    fake_x,
669*da0073e9SAndroid Build Coastguard Worker                    fake_offsets,
670*da0073e9SAndroid Build Coastguard Worker                    fake_offsets,
671*da0073e9SAndroid Build Coastguard Worker                    max_seqlen,
672*da0073e9SAndroid Build Coastguard Worker                    max_seqlen,
673*da0073e9SAndroid Build Coastguard Worker                    0.0,
674*da0073e9SAndroid Build Coastguard Worker                    False,
675*da0073e9SAndroid Build Coastguard Worker                    False,
676*da0073e9SAndroid Build Coastguard Worker                )
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker        dense_x = torch.randn(4, 40, 4, 16, dtype=torch.bfloat16, device="cuda").transpose(1, 2)
679*da0073e9SAndroid Build Coastguard Worker
680*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as real_flop_counter_mode:
681*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten._flash_attention_forward(
682*da0073e9SAndroid Build Coastguard Worker                dense_x,
683*da0073e9SAndroid Build Coastguard Worker                dense_x,
684*da0073e9SAndroid Build Coastguard Worker                dense_x,
685*da0073e9SAndroid Build Coastguard Worker                None,
686*da0073e9SAndroid Build Coastguard Worker                None,
687*da0073e9SAndroid Build Coastguard Worker                max_seqlen,
688*da0073e9SAndroid Build Coastguard Worker                max_seqlen,
689*da0073e9SAndroid Build Coastguard Worker                0.0,
690*da0073e9SAndroid Build Coastguard Worker                False,
691*da0073e9SAndroid Build Coastguard Worker                False,
692*da0073e9SAndroid Build Coastguard Worker            )
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(int(get_total_flops(fake_flop_counter_mode)), int(get_total_flops(real_flop_counter_mode)))
695*da0073e9SAndroid Build Coastguard Worker
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker    def test_addmm_out(self):
698*da0073e9SAndroid Build Coastguard Worker        def f(x):
699*da0073e9SAndroid Build Coastguard Worker            y = torch.zeros(10, 10)
700*da0073e9SAndroid Build Coastguard Worker            return torch.mm(x, x, out=y)
701*da0073e9SAndroid Build Coastguard Worker
702*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
703*da0073e9SAndroid Build Coastguard Worker            f(torch.randn(10, 10))
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """2000""")
706*da0073e9SAndroid Build Coastguard Worker
707*da0073e9SAndroid Build Coastguard Worker    def test_hook_registration(self):
708*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Linear(100, 100)
709*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 100)
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
712*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 1)
713*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 1)
714*da0073e9SAndroid Build Coastguard Worker            model(x).sum().backward()
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 0)
717*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 0)
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker    def test_pytrees(self):
720*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
721*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
722*da0073e9SAndroid Build Coastguard Worker                x = x["a"].relu_()
723*da0073e9SAndroid Build Coastguard Worker                return {"a": torch.mm(x, x)}
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
726*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
727*da0073e9SAndroid Build Coastguard Worker                super().__init__()
728*da0073e9SAndroid Build Coastguard Worker                self.a = Foo()
729*da0073e9SAndroid Build Coastguard Worker                self.b = Foo()
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
732*da0073e9SAndroid Build Coastguard Worker                return self.b(self.a(x))
733*da0073e9SAndroid Build Coastguard Worker
734*da0073e9SAndroid Build Coastguard Worker        mod = Mod()
735*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
736*da0073e9SAndroid Build Coastguard Worker            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
737*da0073e9SAndroid Build Coastguard Worker                "a"
738*da0073e9SAndroid Build Coastguard Worker            ].sum().backward()
739*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
740*da0073e9SAndroid Build Coastguard Worker            (mode.flop_counts["Mod"][torch.ops.aten.mm]), """12000"""
741*da0073e9SAndroid Build Coastguard Worker        )
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker        class Mod2(torch.nn.Module):
744*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
745*da0073e9SAndroid Build Coastguard Worker                return (torch.mm(x, x),)
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker        mod = Mod2()
748*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode() as mode:
749*da0073e9SAndroid Build Coastguard Worker            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
750*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
751*da0073e9SAndroid Build Coastguard Worker            (mode.flop_counts["Mod2"][torch.ops.aten.mm]), """6000"""
752*da0073e9SAndroid Build Coastguard Worker        )
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Worker    def test_warning(self):
755*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.Linear(2, 2)
756*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, "not needed"):
757*da0073e9SAndroid Build Coastguard Worker            FlopCounterMode(mod)
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker    def test_custom_op(self):
760*da0073e9SAndroid Build Coastguard Worker        from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker        @torch.library.custom_op("mylib::foo", mutates_args=())
763*da0073e9SAndroid Build Coastguard Worker        def foo(x: torch.Tensor) -> torch.Tensor:
764*da0073e9SAndroid Build Coastguard Worker            return x.sin()
765*da0073e9SAndroid Build Coastguard Worker
766*da0073e9SAndroid Build Coastguard Worker        called = 0
767*da0073e9SAndroid Build Coastguard Worker
768*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "expected each target to be OpOverloadPacket"):
769*da0073e9SAndroid Build Coastguard Worker            register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker        @register_flop_formula(torch.ops.mylib.foo)
772*da0073e9SAndroid Build Coastguard Worker        def formula(*args, **kwargs):
773*da0073e9SAndroid Build Coastguard Worker            nonlocal called
774*da0073e9SAndroid Build Coastguard Worker            called += 1
775*da0073e9SAndroid Build Coastguard Worker            return 9001
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3)
778*da0073e9SAndroid Build Coastguard Worker        with FlopCounterMode(display=False) as mode:
779*da0073e9SAndroid Build Coastguard Worker            y = foo(x)
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(called, 1)
782*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(get_total_flops(mode), """9001""")
783*da0073e9SAndroid Build Coastguard Worker
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
786*da0073e9SAndroid Build Coastguard Worker    run_tests()
787