# Owner(s): ["module: inductor"]
import functools
import itertools
import math

import torch
import torch._inductor.config
import torch.utils.checkpoint
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM80OrLater,
)
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


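# Wrap an attention implementation in activation checkpointing so the same SDPA
# pattern is also exercised through torch.utils.checkpoint; the pattern matcher
# is expected to match through the checkpointed region as well.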
def checkpoint_wrapper(fn):
    def inner(*args):
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

    return inner


class TestSDPAPatternRewriterTemplate(TestCase):
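    # Template class: the _test_* helpers below are not collected directly; they
    # are bound as concrete test_* methods on the per-device classes at the
    # bottom of this file (CUDA/CPU, static and dynamic shapes).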
    use_static_shapes = True

    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return [clone(x) for x in inputs]

    def _check_common(
        self,
        dot_prod_attention,
        args1=None,
        contains=True,
        atol=1e-5,
        has_fuse_pattern=True,
        has_dropout=False,
        check_train=True,
        override_check_equal=False,
        dtype=torch.float,
        rtol=1.3e-6,
    ):
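        """Run dot_prod_attention eagerly and under torch.compile and compare.

        When has_fuse_pattern is True, assert that Inductor's fuse_attention
        counter fired at least once; when contains is True, additionally assert
        that an aten._scaled_dot_product* op appears in the generated code.
        Forward outputs (and input gradients when training) are compared whenever
        there is no dropout, or when override_check_equal is set.
        """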
        if args1 is None:
            tensor_shape = (4, 2, 16, 32)
            args1 = [
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
            ]
        else:
            args1 = list(args1)
        args2 = self._clone_inputs(args1)

        for training in [False, True] if check_train else [False]:
            for x in itertools.chain(args1[:], args2[:]):
                if isinstance(x, torch.Tensor) and x.is_floating_point():
                    x.requires_grad = training

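            # The *DynamicTests subclasses set use_static_shapes = False; marking
            # the batch dimension dynamic exercises the patterns under dynamic shapes.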
            if not self.use_static_shapes:
                torch._dynamo.mark_dynamic(args2[0], 0)
                torch._dynamo.mark_dynamic(args2[1], 0)
                torch._dynamo.mark_dynamic(args2[2], 0)

            dropout_arg = [training] if has_dropout else []
            torch.manual_seed(1234)
            result1 = dot_prod_attention(*(args1 + dropout_arg))

            counters.clear()
            torch.manual_seed(1234)
            result2, source_code = run_and_get_code(
                torch.compile(dot_prod_attention, fullgraph=True),
                *(args2 + dropout_arg),
            )
            source_code = "\n".join(source_code)
            if has_fuse_pattern:
                self.assertGreaterEqual(counters["inductor"]["fuse_attention"], 1)
            if contains:
                # many of the patterns get re-expanded in the dispatcher
                self.assertIn(
                    "aten._scaled_dot_product",
                    source_code,
                )

            # some tests are configured with very low dropout, where we still want to check equality
            if not has_dropout or override_check_equal:
                self.assertEqual(result1, result2, atol=atol, rtol=rtol)

            if training:
                result1.sum().backward()
                result2.sum().backward()
                for arg1, arg2 in zip(args1, args2):
                    if (
                        isinstance(arg1, torch.Tensor)
                        and arg1.is_floating_point()
                        and (not has_dropout or override_check_equal)
                    ):
                        self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

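    # Each rewriter test below spells out an attention computation that Inductor's
    # fuse_attention pass is expected to replace with a fused SDPA op, roughly
    # equivalent to torch.nn.functional.scaled_dot_product_attention(query, key, value);
    # the exact aten overload chosen depends on the backend.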
    @skipIfRocm
    def _test_sdpa_rewriter_1(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        for dtype in [torch.float, torch.half]:
            atol = 0.001
            rtol = 1.3e-6 if dtype == torch.float else 0.7
            if self.device == "cpu" and dtype == torch.half:
                atol = 2e-3
                rtol = 1e-2
            self._check_common(dot_prod_attention, dtype=dtype, atol=atol, rtol=rtol)
            self._check_common(
                checkpoint_wrapper(dot_prod_attention),
                dtype=dtype,
                atol=atol,
                rtol=rtol,
            )

    @skipIfRocm
    @torch._inductor.config.patch("freezing", True)
    def _test_sdpa_rewriter_1_freezing(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        for dtype in [torch.float, torch.half]:
            atol = 0.001
            rtol = 1.3e-6 if dtype == torch.float else 0.7
            if self.device == "cpu" and dtype == torch.half:
                atol = 2e-3
                rtol = 1e-2
            with torch.no_grad():
                self._check_common(
                    dot_prod_attention,
                    dtype=dtype,
                    atol=atol,
                    rtol=rtol,
                    check_train=False,
                )

    @skipIfRocm
    def _test_insignificant_strides(self):
        f32 = torch.float32

        # repro taken from https://github.com/pytorch/pytorch/issues/124289
        # the result of constant_pad_nd is a single-element tensor that gets expanded

        def forward(
            permute_3: "f32[1, 32, 1, 128]",
            permute_4: "f32[1, 32, 1, 128]",
            permute_5: "f32[1, 32, 1, 128]",
            permute_6: "f32[1, 1, 64]",
            mul_2: "f32[1, 1, 1, 1]",
        ):
            cat = torch.ops.aten.cat.default([permute_6, permute_6], 2)
            permute_6 = None
            cos = torch.ops.aten.cos.default(cat)
            sin = torch.ops.aten.sin.default(cat)
            unsqueeze_10 = torch.ops.aten.unsqueeze.default(cos, 1)
            cos = None
            unsqueeze_11 = torch.ops.aten.unsqueeze.default(sin, 1)
            sin = None
            mul_5 = torch.ops.aten.mul.Tensor(permute_3, unsqueeze_10)
            slice_10 = torch.ops.aten.slice.Tensor(permute_3, 3, 0, 64)
            slice_11 = torch.ops.aten.slice.Tensor(
                permute_3, 3, 64, 9223372036854775807
            )
            permute_3 = None
            neg = torch.ops.aten.neg.default(slice_11)
            slice_11 = None
            cat_1 = torch.ops.aten.cat.default([neg, slice_10], 3)
            neg = slice_10 = None
            mul_6 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_11)
            cat_1 = None
            add_1 = torch.ops.aten.add.Tensor(mul_5, mul_6)
            mul_5 = mul_6 = None
            mul_7 = torch.ops.aten.mul.Tensor(permute_4, unsqueeze_10)
            unsqueeze_10 = None
            slice_12 = torch.ops.aten.slice.Tensor(permute_4, 3, 0, 64)
            slice_13 = torch.ops.aten.slice.Tensor(
                permute_4, 3, 64, 9223372036854775807
            )
            permute_4 = None
            neg_1 = torch.ops.aten.neg.default(slice_13)
            slice_13 = None
            cat_2 = torch.ops.aten.cat.default([neg_1, slice_12], 3)
            neg_1 = slice_12 = None
            mul_8 = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_11)
            cat_2 = unsqueeze_11 = None
            add_2 = torch.ops.aten.add.Tensor(mul_7, mul_8)
            mul_7 = mul_8 = None
            slice_14 = torch.ops.aten.slice.Tensor(mul_2, 0, 0, 9223372036854775807)
            mul_2 = None
            slice_15 = torch.ops.aten.slice.Tensor(slice_14, 1, 0, 9223372036854775807)
            slice_14 = None
            slice_16 = torch.ops.aten.slice.Tensor(slice_15, 2, 0, 9223372036854775807)
            slice_15 = None
            constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                slice_16, [0, 7], 0.0
            )
            slice_16 = None
            slice_17 = torch.ops.aten.slice.Tensor(constant_pad_nd, -1, 0, 1)
            constant_pad_nd = None
            expand_5 = torch.ops.aten.expand.default(slice_17, [1, 32, 1, 1])
            _scaled_dot_product_efficient_attention = (
                torch.ops.aten._scaled_dot_product_efficient_attention.default(
                    add_1, add_2, permute_5, expand_5, True
                )
            )
            return _scaled_dot_product_efficient_attention

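        # aot_graph_input_parser builds suitable random inputs from the dtype/shape
        # annotations in the signature above (e.g. "f32[1, 32, 1, 128]").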
        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        out_eager = forward(**kwargs)
        out_c = torch.compile(forward)(**kwargs)
        # don't compare philox_seed/offset
        torch.testing.assert_close(out_eager[0:2], out_c[0:2])

    def _test_pattern_fails_with_reuse(self):
        """
        This test checks that the replacement is not done
        when an intermediate result is being used / returned downstream
        """

        @torch.compile(fullgraph=True)
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            attn_weights = (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
            )
            return attn_weights.matmul(value), attn_weights

        tensor_shape = (2, 4, 8, 16)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        _, (source_code,) = run_and_get_code(dot_prod_attention, *args)
        self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)

    @skipIfRocm
    def _test_sdpa_rewriter_2(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        self._check_common(dot_prod_attention)
        self._check_common(checkpoint_wrapper(dot_prod_attention))

    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
    def _test_sdpa_rewriter_3(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool
        ) -> torch.Tensor:
            return torch.nn.functional.dropout(
                torch.matmul(query, key.transpose(-2, -1)).div(3.0).softmax(dim=-1),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(value)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
        )

    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
    def _test_sdpa_rewriter_4(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
            return torch.nn.functional.dropout(
                torch.matmul(query, key.transpose(-2, -1)).mul(0.4).softmax(dim=-1),
                p=0.2,
                inplace=False,
                training=training,
            ).matmul(value)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
        )

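    # Patterns 5 and 6 add an explicit attention mask before the softmax.
    # contains=False: the fuse_attention counter is still checked, but the
    # generated code is not required to contain an aten._scaled_dot_product call.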
    def _test_sdpa_rewriter_5(self):
        def sfdp_pattern_5_v1(query, key, value):
            attn_mask = torch.ones(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            return attn_weight @ value

        def sfdp_pattern_5_v2(query, key, value):
            # https://github.com/pytorch/pytorch/issues/100318.
            attn_mask = torch.zeros(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).bool()
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            return attn_weight @ value

        self._check_common(sfdp_pattern_5_v1, contains=False)
        self._check_common(checkpoint_wrapper(sfdp_pattern_5_v1), contains=False)
        self._check_common(sfdp_pattern_5_v2, contains=False)
        self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False)

    @skipIfRocm
    def _test_sdpa_rewriter_6(self):
        def sfdp_pattern_6(query, key, value, training):
            attn_mask = torch.ones(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            attn_weight = torch.nn.functional.dropout(attn_weight, 0.5, training)
            return attn_weight @ value

        self._check_common(sfdp_pattern_6, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_6), contains=False, has_dropout=True
        )

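    # Patterns 7-10 compute scores in fp16 but upcast to fp32 for the softmax
    # before casting back; for the dropout variants (7 and 9) the fused op is
    # only required in the generated code on SM80+ (contains=SM80OrLater).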
    @skipIfRocm
    def _test_sdpa_rewriter_7(self):
        def sfdp_pattern_7(query, key, value, training):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            # near-zero dropout so results stay comparable with eager
            attn_weight = torch.dropout(attn_weight, 0.00000000001, training)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(
            sfdp_pattern_7,
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_7),
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_8(self):
        def sfdp_pattern_8(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(sfdp_pattern_8, args, atol=2e-3)

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3)

    @skipIfRocm
    def _test_sdpa_rewriter_9(self):
        def sfdp_pattern_9(query, key, value, training):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            q = q / math.sqrt(q.size(-1))
            div = q @ k.transpose(-2, -1)
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            # very low dropout to make test pass
            attn_weight = torch.dropout(attn_weight, 0.00000000001, training)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(
            sfdp_pattern_9,
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )
        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_9),
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_10(self):
        def sfdp_pattern_10(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            q = q / math.sqrt(q.size(-1))
            div = q @ k.transpose(-2, -1)
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(sfdp_pattern_10, args, atol=2e-3)

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(checkpoint_wrapper(sfdp_pattern_10), args, atol=2e-3)

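    # Negative test: the SDPA patterns do not match a tensor-valued scale factor,
    # so only numerical correctness is checked here (has_fuse_pattern=False skips
    # the counter assertion).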
    def _test_pattern_fails_with_tensor_factor(self):
        # https://github.com/pytorch/pytorch/issues/99124
        class Model(torch.nn.Module):
            def __init__(self, is_inv_factor):
                super().__init__()
                self.is_inv_factor = is_inv_factor

            def forward(self, query, key, value, scale_factor) -> torch.Tensor:
                # Dividing by scale_factor makes scale_factor gradients very
                # unstable
                scale_factor = scale_factor.detach()
                y = torch.matmul(query, key.transpose(-2, -1))
                if self.is_inv_factor:
                    y = y.div(scale_factor)
                else:
                    y = y.mul(scale_factor)
                return y.softmax(dim=-1).matmul(value)

        tensor_shape = (2, 4, 4, 4)
        for is_inv_factor in [True, False]:
            args = [
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn((4, 1, 1), device=self.device),
            ]
            model = Model(is_inv_factor).eval()
            # The training path has an accuracy gap compared with eager mode.
            self._check_common(
                model, args1=args, contains=False, atol=1e-3, has_fuse_pattern=False
            )

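    # Integer tensors and plain Python scalars are not supported attention-mask
    # types for the SDPA patterns, so again only numerical correctness is checked.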
    def _test_pattern_fails_with_unsupported_mask(self):
        if not self.use_static_shapes:
            self.skipTest("Causes shape specialization. TODO: investigate")

        # https://github.com/pytorch/pytorch/issues/100315
        class Model(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, query, key, value, attn_mask) -> torch.Tensor:
                attn_weight = torch.softmax(
                    query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
                    + attn_mask,
                    dim=-1,
                )
                return attn_weight @ value

        tensor_shape = (2, 4, 4, 4)

        unsupported_masks = [
            torch.randn((2, 4, 4, 4), device=self.device).to(dtype=torch.int),
            2.0,
        ]
        for attn_mask in unsupported_masks:
            args = [
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                attn_mask,
            ]
            model = Model().eval()
            # The training path has an accuracy gap compared with eager mode.
            self._check_common(
                model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
            )

    @skipIfRocm
    def _test_sdpa_rewriter_11(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(v)
            )

        self._check_common(dot_prod_attention)

    def _test_sdpa_rewriter_12(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return torch.nn.functional.dropout(
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(v),
                p=0.4,
                training=training,
                inplace=False,
            )

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

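    # The _test_sdpa_prev_* variants keep an extra .clone() before the final
    # matmul; presumably they cover pattern shapes registered by earlier versions
    # of the fuse_attention pass.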
    @skipIfRocm
    def _test_sdpa_prev_13(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(value)
            )

        self._check_common(dot_prod_attention, check_train=False)
        self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

    @skipIfRocm
    def _test_sdpa_prev_14(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(value)
            )

        self._check_common(dot_prod_attention, check_train=False)
        self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

    @skipIfRocm
    def _test_sdpa_prev_15(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(v)
            )

        self._check_common(dot_prod_attention, check_train=False)

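    # Pattern 13 uses 3-D (batched-matmul) inputs rather than the 4-D
    # (batch, heads, seq, dim) layout used by the other patterns.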
    @skipIfRocm
    def _test_sdpa_rewriter_13(self, dtype):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
679            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
            attn_weight = torch.nn.functional.dropout(
                attn_weight, p=0.5, training=training
            )
            return torch.bmm(attn_weight, value)

        tensor_shape = (4, 8, 16)
        args = [
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
        ]

        self._check_common(
            dot_prod_attention,
            check_train=False,
            args1=args,
            has_dropout=True,
            override_check_equal=True,
            atol=1e-2,
            rtol=1e-2,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_14(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.ones(
                query.size(1), key.size(1), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return (
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask)
                .softmax(dim=-1)
                .matmul(v)
            )

        self._check_common(dot_prod_attention)

    @skipIfRocm
    def _test_sdpa_rewriter_15(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            bs = q.size(0)
            k_len = k.size(-2)
            attn_mask = torch.ones(
                bs, k_len, dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            scores = torch.matmul(q, k.transpose(-2, -1)) / 3.0
            attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
            scores = scores.masked_fill(attn_mask, -float("inf"))
            weights = torch.nn.functional.softmax(scores, dim=-1)
            return torch.matmul(weights, v)

        self._check_common(dot_prod_attention, check_train=False)

    @skipIfRocm
    def _test_sdpa_rewriter_16(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.ones(
                query.size(1), key.size(1), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return torch.nn.functional.dropout(
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask).softmax(
                    dim=-1
                ),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(v)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        self._check_common(
            dot_prod_attention, args1=args, contains=False, has_dropout=True
        )

    @skipIfRocm
    def _test_sdpa_rewriter_16_fp32_mask(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.randn(
                query.size(1), key.size(1), dtype=torch.float, device=query.device
            ).tril(diagonal=0)
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return torch.nn.functional.dropout(
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask).softmax(
                    dim=-1
                ),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(v)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        self._check_common(
            dot_prod_attention, args1=args, contains=False, has_dropout=True
        )

    @skipIfRocm
    def _test_sdpa_rewriter_17(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            bs = q.size(0)
            k_len = k.size(-2)
            attn_mask = torch.ones(
                bs, k_len, dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            scores = torch.matmul(q, k.transpose(-2, -1)) / 3.0
            attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
            scores = scores.masked_fill(attn_mask, -float("inf"))
            weights = torch.nn.functional.softmax(scores, dim=-1)
            weights = torch.nn.functional.dropout(
                weights,
                p=0.4,
                training=training,
                inplace=False,
            )
            return torch.matmul(weights, v)

        self._check_common(dot_prod_attention, check_train=False, has_dropout=True)

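    # Pattern 18 mirrors hf_GPT2's attention: besides the attention output it
    # also returns the permuted key/value, and it is only checked in inference mode.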
    @skipIfRocm
    def _test_sdpa_rewriter_18(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            causal_mask: torch.Tensor,
        ) -> torch.Tensor:
            # for hf_GPT2 with dropout
            query = query.permute([0, 2, 1, 3])
            key = key.permute([0, 2, 1, 3])
            value = value.permute([0, 2, 1, 3])
            attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
            inv_scale = torch.full(
                (), math.sqrt(value.size(-1)), dtype=query.dtype, device=query.device
            )
            attn_weights = attn_weights.div(inv_scale)
            causal_mask_value = torch.full(
                (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
            )
            attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
            return (
                (
                    torch.nn.functional.dropout(
                        attn_weights.softmax(dim=-1), 0.0
                    ).matmul(value)
                ),
                key.permute([0, 2, 1, 3]),
                value.permute([0, 2, 1, 3]),
            )

        tensor_shape = (4, 2, 16, 32)
        causal_mask = torch.ones(2, 2, dtype=torch.bool, device=self.device).tril(
            diagonal=0
        )
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=False,
            check_train=False,
        )

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=False,
            check_train=False,
        )

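    # Pattern 19 combines a causal torch.where mask with an additive attention
    # mask before the softmax, plus dropout.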
    @skipIfRocm
    def _test_sdpa_rewriter_19(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            causal_mask: torch.Tensor,
            attn_mask: torch.Tensor,
            training,
        ) -> torch.Tensor:
            attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
            inv_scale = torch.full(
                (),
                math.sqrt(value.size(-1)),
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )
            attn_weights = attn_weights.div(inv_scale)
            causal_mask_value = torch.full(
                (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
            )
            attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
            attn_weights = attn_weights + attn_mask
            attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
            return torch.nn.functional.dropout(
                attn_weights,
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(value)

        tensor_shape = (4, 2, 16, 32)
        causal_mask = torch.ones(16, 16, dtype=torch.bool, device=self.device).tril(
            diagonal=0
        )
        attn_mask = torch.randn((16, 16), dtype=torch.float, device=self.device)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
            attn_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=True,
            check_train=False,
        )


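# Bind the template's _test_* helpers as concrete per-device tests. CUDA tests
# require a platform with fused-attention kernels; the *DynamicTests subclasses
# rerun the same tests with the batch dimension marked dynamic.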
if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:

    class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
        device = "cuda"
        test_sdpa_rewriter_1_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
        )
        test_sdpa_rewriter_1_freezing = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1_freezing
        )
        test_insignificant_strides = (
            TestSDPAPatternRewriterTemplate._test_insignificant_strides
        )
        test_pattern_fails_with_reuse_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
        )
        test_sdpa_rewriter_2_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
        )
        test_sdpa_rewriter_3_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
        )
        test_sdpa_rewriter_4_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
        )
        test_sdpa_rewriter_5_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
        )
        test_sdpa_rewriter_6_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
        )
        test_sdpa_rewriter_7_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
        )
        test_sdpa_rewriter_8_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
        )
        test_sdpa_rewriter_9_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
        )
        test_sdpa_rewriter_10_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_10
        )
        test_pattern_fails_with_tensor_factor_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
        )
        test_pattern_fails_with_unsupported_mask_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
        )
        test_sdpa_rewriter_11_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
        )
        test_sdpa_rewriter_12_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
        )
        test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
        test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
        test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
        test_sdpa_rewriter_13_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
        )
        test_sdpa_rewriter_14_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
        )
        test_sdpa_rewriter_15_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
        )
        test_sdpa_rewriter_17_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
        )
        test_sdpa_rewriter_19_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
        )

    class SDPAPatternRewriterCudaDynamicTests(SDPAPatternRewriterCudaTests):
        use_static_shapes = False


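# The CPU class binds its own selection of patterns: the fp16 patterns (7-10),
# the dropout variants 3/4/6, the freezing test, and the strides repro are
# CUDA-only, while patterns 16 and 18 are currently only bound here.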
if HAS_CPU:

    class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
        device = "cpu"
        test_sdpa_rewriter_1_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
        test_pattern_fails_with_reuse_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
        )
        test_sdpa_rewriter_2_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
        test_sdpa_rewriter_5_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
        test_pattern_fails_with_tensor_factor_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
        )
        test_pattern_fails_with_unsupported_mask_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
        )
        test_sdpa_rewriter_11_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
        )
        test_sdpa_rewriter_12_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
        )
        test_sdpa_prev_13_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
        test_sdpa_prev_14_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
        test_sdpa_prev_15_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
        test_sdpa_rewriter_13_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.float32
        )
        test_sdpa_rewriter_14_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
        )
        test_sdpa_rewriter_15_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
        )
        test_sdpa_rewriter_16_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16
        )
        test_sdpa_rewriter_16_fp32_mask_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16_fp32_mask
        )
        test_sdpa_rewriter_17_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
        )
        test_sdpa_rewriter_18_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_18
        )
        test_sdpa_rewriter_19_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
        )

    class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests):
        use_static_shapes = False


if __name__ == "__main__":
    if IS_LINUX:
        run_tests()