# Owner(s): ["module: nn"]
import math
import copy

import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCUDA,
    skipMeta,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_WITH_ROCM

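# Tests comparing the private fused ops torch._transform_bias_rescale_qkv and
# torch._native_multi_head_attention against reference implementations
# (a hand-written transform and torch.nn.MultiheadAttention, respectively).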
class TestMHADeviceType(TestCase):
    @torch.no_grad()
    def _test_transform_bias_rescale_qkv_impl(
        self, device, dtype, use_nt, use_padding=False
    ):
        tests = [
            (64, 4, 16, 8),
            # dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
            (24, 2, 4, 2),
            # Make sure CUDA can handle small input sizes
            (2, 2, 2, 2),
            # dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4,
            # causes alignment issues
            (24, 4, 4, 2),
            (48, 4, 16, 8),
        ]
        for (embed_dim, num_heads, bs, sl) in tests:
            with self.subTest(embed_dim=embed_dim, num_heads=num_heads, bs=bs, sl=sl):
                torch.manual_seed(9343)
                dense_x = x = (
                    torch.randn(bs, sl, 3 * embed_dim, device=device, dtype=dtype) * 10
                )
                if use_padding:
                    x[0][-1] = torch.full(x[0][-1].shape, float("-Inf"))
                if use_nt:
                    xs = list(torch.unbind(x))
                    if use_padding:
                        xs[0] = xs[0][:-1]
                    x = torch.nested.nested_tensor(xs, device=device, dtype=dtype)
                qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)

                # We have to use inference_mode here because q/k/v are
                # all views of the same Tensor, which autograd doesn't
                # like. This is fine because this function is only
                # exposed to Python for purposes of writing this test.
                with torch.inference_mode():
                    (q, k, v) = torch._transform_bias_rescale_qkv(
                        x, qkv.bias, num_heads=num_heads
                    )

                    def simple_transform_bias_rescale_qkv(qkv, bias):
                        (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
                        (q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)

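                        # In the nested-tensor case the fused kernel pads the sequence
                        # dimension of its output to a multiple of 8; embiggen pads the
                        # dense reference the same way so the comparison below lines up.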
                        def embiggen(x):
                            if not use_nt:
                                return x
                            b, t, d = x.size()
                            t = t + (8 - t % 8) % 8
                            newsize = (b, t, d)
                            new_x = torch.zeros(newsize, device=device, dtype=dtype)
                            new_x[:x.size()[0], :x.size()[1], :x.size()[2]] = x
                            return new_x
                        return tuple(
                            embiggen(x).reshape(
                                (bs, -1, num_heads, embed_dim // num_heads)
                            ).transpose(2, 1)
                            for x in (
                                (q + q_bias) / math.sqrt(embed_dim // num_heads),
                                (k + k_bias),
                                (v + v_bias),
                            )
                        )

                    correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(
                        dense_x, qkv.bias
                    )
                    if use_nt and use_padding:
                        for t in (correct_q, correct_k, correct_v):
                            t[t == float("-Inf")] = 0

                self.assertEqual(q.size(), correct_q.size())
                torch.testing.assert_close(q, correct_q)
                torch.testing.assert_close(k, correct_k)
                torch.testing.assert_close(v, correct_v)

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    def test_transform_bias_rescale_qkv(self, device, dtype):
        for use_padding in (False, True):
            with self.subTest(use_padding=use_padding):
                self._test_transform_bias_rescale_qkv_impl(
                    device, dtype, use_nt=False, use_padding=use_padding
                )

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    @onlyCUDA
    def test_transform_bias_rescale_qkv_nested(self, device, dtype):
        for use_padding in (False, True):
            with self.subTest(use_padding=use_padding):
                self._test_transform_bias_rescale_qkv_impl(
                    device, dtype, use_nt=True, use_padding=use_padding
                )

    def _test_multihead_attention_impl(
        self, device, dtype, mode, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=False
    ):
        embed_dim = 64
        num_heads = 4
        bs = 16
        sl = 8

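        # When padding is requested, zero out the last position of the first
        # sequence (or of every sequence when pad_all) and build the matching
        # boolean key_padding_mask.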
        q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
        if use_padding:
            if pad_all:
                for q_i in q:
                    q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32)
                mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
                for mask_i in mask:
                    mask_i[-1] = True
            else:
                q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32)
                mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
                mask[0][-1] = True
        if mode == "self":
            k = q
            v = q
        elif mode == "encdec":
            k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
            v = k
        elif mode == "generic":
            k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
            v = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
        else:
            self.fail(f"invalid mode `{mode}`!")

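        # Create float32 projection layers for the reference nn.MultiheadAttention
        # and deep copies cast to the test dtype for the native kernel, so both
        # paths run with identical weights.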
        qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=torch.float32)
        native_qkv = copy.deepcopy(qkv).to(dtype=dtype)

        proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
        native_proj = copy.deepcopy(proj).to(dtype=dtype)

        pt = torch.nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True, device=device, dtype=torch.float32
        )

        pt.in_proj_weight = qkv.weight
        pt.in_proj_bias = qkv.bias
        pt.out_proj.weight = proj.weight
        pt.out_proj.bias = proj.bias

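        # Minimal module wrapper so the private fused op can be called with the
        # same (q, k, v, key_padding_mask) interface as the reference module.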
        class NativeMHA(torch.nn.Module):
            def __init__(self, embed_dim, num_heads, qkv, proj):
                super().__init__()
                self.qkv = qkv
                self.proj = proj
                self.embed_dim = embed_dim
                self.num_heads = num_heads

            def forward(self, q, k, v, key_padding_mask):
                return torch._native_multi_head_attention(
                    q,
                    k,
                    v,
                    self.embed_dim,
                    self.num_heads,
                    self.qkv.weight,
                    self.qkv.bias,
                    self.proj.weight,
                    self.proj.bias,
                    key_padding_mask,
                    need_weights=need_weights,
                    average_attn_weights=average_attn_weights,
                    mask_type=1,   # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
                )

        npt = NativeMHA(
            embed_dim=embed_dim, num_heads=num_heads, qkv=native_qkv, proj=native_proj
        ).to(dtype)

        if device == "cuda":
            pt = pt.cuda()
            npt = npt.cuda()

        ypt, weight_pt = pt(
            q,
            k,
            v,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            key_padding_mask=mask if use_padding else None,
        )
        if use_nt:
            qs = list(torch.unbind(q))
            if use_padding:
                if pad_all:
                    qs = [x[:-1] for x in qs]
                else:
                    qs[0] = qs[0][:-1]
            q = torch.nested.nested_tensor(qs, device=device, dtype=dtype)
            if mode == "self":
                k = v = q
            elif mode == "encdec":
                k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
                v = k
            else:
                k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
                v = torch.nested.nested_tensor(torch.unbind(v), device=device, dtype=dtype)

        native_q = q.to(dtype=dtype)
        native_k = k.to(dtype=dtype)
        native_v = v.to(dtype=dtype)

        ynpt, weight_npt = npt(
            native_q, native_k, native_v, key_padding_mask=mask if use_padding and not use_nt else None
        )
        if use_nt:
            ynpt = ynpt.to_padded_tensor(0)
            if pad_all:
                ynpt_final = torch.zeros_like(ypt)
                ynpt_final[:, :ynpt.shape[1], :] = ynpt
                ynpt = ynpt_final

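        # Zero the last (padded) position of every sequence in each given tensor.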
        def do_pad_all(tensors):
            for t in tensors:
                for t_i in t:
                    t_i[-1] = torch.zeros_like(t_i[-1], device=device, dtype=dtype)

        # PyTorch implementation returns non-zero junk in the padding
        # locations; overwrite it so that the comparison works out.
        if use_padding:
            ypt[0][-1] = torch.zeros_like(ypt[0][-1], device=device, dtype=dtype)
            ynpt[0][-1] = torch.zeros_like(ynpt[0][-1], device=device, dtype=dtype)
            if pad_all:
                do_pad_all((ypt, ynpt))
            # Zero the last row of each TxT weight matrix
            if need_weights:
                if average_attn_weights:
                    weight_pt[0][-1] = torch.zeros_like(weight_pt[0][-1], device=device, dtype=dtype)
                    weight_npt[0][-1] = torch.zeros_like(weight_npt[0][-1], device=device, dtype=dtype)
                    if pad_all:
                        do_pad_all((weight_pt, weight_npt))
                else:
                    for nh in range(num_heads):
                        weight_pt[0][nh][-1] = torch.zeros_like(weight_pt[0][nh][-1], device=device, dtype=dtype)
                        weight_npt[0][nh][-1] = torch.zeros_like(weight_npt[0][nh][-1], device=device, dtype=dtype)

        if dtype == torch.half:
            torch.testing.assert_close(ypt, ynpt.to(torch.float32), atol=1e-3, rtol=1e-3)
        else:
            # High rtol seems necessary for
            # test_native_multihead_attention_cpu_float32 on Windows,
            # otherwise 2e-4 would likely be fine.
            torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3)

        if need_weights:
            torch.testing.assert_close(weight_pt, weight_npt.to(torch.float32), atol=5e-4, rtol=5e-4)
        else:
            self.assertEqual(weight_pt, weight_npt)

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @parametrize("use_nt", [False, True])
    @parametrize("use_padding, pad_all", [(False, False), (True, False), (True, True)])
    @parametrize("need_weights", [False])
    @parametrize("average_attn_weights", [False, True])
    @parametrize("fused", [False, True])
    @torch.no_grad()
    def test_native_multihead_self_attention(self, device, dtype, use_nt,
                                             need_weights, average_attn_weights, use_padding, pad_all, fused):
        if TEST_WITH_ROCM:
            if use_nt:
                self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
            if use_padding and not pad_all and fused:
                self.skipTest("Large numerical errors on ROCM to investigate.")
        for need_weights in (False, not pad_all):
            with self.subTest(use_padding=use_padding, pad_all=pad_all,
                              use_nt=use_nt, need_weights=need_weights,
                              average_attn_weights=average_attn_weights):
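                # fused=False disables the flash and memory-efficient SDPA backends
                # so the math fallback is used; fused=True explicitly enables both.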
                with torch.backends.cuda.sdp_kernel(
                        enable_flash=False, enable_mem_efficient=False
                ) if not fused else torch.backends.cuda.sdp_kernel(
                        enable_flash=True, enable_mem_efficient=True
                ):
                    self._test_multihead_attention_impl(
                        device,
                        dtype,
                        "self",
                        use_nt=use_nt,
                        use_padding=use_padding,
                        pad_all=pad_all,
                        need_weights=need_weights,
                        average_attn_weights=average_attn_weights,
                    )

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_native_multihead_encoder_decoder_attention(self, device, dtype):
        self._test_multihead_attention_impl(
            device,
            dtype,
            "encdec",
            use_nt=False,
            need_weights=False,
            average_attn_weights=False,
        )

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_native_multihead_attention(self, device, dtype):
        self._test_multihead_attention_impl(
            device,
            dtype,
            "generic",
            use_nt=False,
            need_weights=False,
            average_attn_weights=False,
        )


instantiate_device_type_tests(TestMHADeviceType, globals())

if __name__ == "__main__":
    run_tests()