# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.extension.llm.modules.attention import (
    MultiHeadAttention as ETMultiHeadAttention,
)
from executorch.runtime import Runtime
from torch._inductor.package import load_package, package_aoti
from torch.testing import assert_close
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention


class AttentionTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(0)
        # Constants
        self.embed_dim = 2048
        self.num_heads = 8
        self.num_kv_heads = 8
        self.head_dim = 64
        self.max_seq_len = 128
        self.rope_base = 500_000
        self.scale_factor = 32

        # Module dependency injections.
        self.q_proj = torch.nn.Linear(
            self.embed_dim, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = torch.nn.Linear(
            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.k_proj.weight.requires_grad = False
        self.v_proj = torch.nn.Linear(
            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.v_proj.weight.requires_grad = False
        self.output_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.embed_dim, bias=False
        )
        self.pos_embeddings = Llama3ScaledRoPE(
            dim=self.head_dim,
            max_seq_len=self.max_seq_len,
            base=self.rope_base,
            scale_factor=self.scale_factor,
        )

        # Original TorchTune reference module to test accuracy against.
        self.tt_mha = TTMultiHeadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            q_proj=self.q_proj,
            k_proj=self.k_proj,
            v_proj=self.v_proj,
            output_proj=self.output_proj,
            pos_embeddings=self.pos_embeddings,
            max_seq_len=self.max_seq_len,
        )

        # Source transformed module that we are testing.
        self.et_mha = ETMultiHeadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            q_proj=self.q_proj,
            k_proj=self.k_proj,
            v_proj=self.v_proj,
            output_proj=self.output_proj,
            pos_embeddings=self.pos_embeddings,
            max_seq_len=self.max_seq_len,
        )
        self.et_mha.load_state_dict(self.tt_mha.state_dict())

        # Common inputs.
        seq_len = 10
        self.x = torch.randn(1, seq_len, self.embed_dim)
        self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
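        # Dynamic-shape specs for the (x, y, input_pos) export inputs: batch and
        # embedding dims stay static, sequence length may vary between 1 and 100.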
        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
        self.dynamic_shapes = (
            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
        )
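        # Lower-triangular boolean causal mask over the full max_seq_len; rows are
        # sliced by input_pos in the torch.cond test below.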
        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )

    def test_attention_eager(self):
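        # Compare the ET and torchtune attention modules in eager mode, with and
        # without a KV cache.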
        et_res = self.et_mha(self.x, self.x)  # Self attention.
        tt_res = self.tt_mha(self.x, self.x)  # Self attention.

        assert_close(et_res, tt_res)

        # test with kv cache
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)

        et_res = self.et_mha(self.x, self.x)  # Self attention.
        tt_res = self.tt_mha(self.x, self.x)  # Self attention.

        self.assertTrue(torch.allclose(et_res, tt_res))
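        # Reset both caches before exercising an explicit input_pos, so cache
        # positions start from 0 again.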
        self.et_mha.reset_cache()
        self.tt_mha.reset_cache()

        et_res = self.et_mha(
            self.x, self.x, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, input_pos=self.input_pos
        )  # Self attention with input pos.

        self.assertTrue(torch.allclose(et_res, tt_res))

        # Test KV-cache reads; input positions are now [10, 11, ..., 19].
        next_input_pos = torch.arange(10, 20).unsqueeze(0)
        et_res = self.et_mha(
            self.x, self.x, input_pos=next_input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, input_pos=next_input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)

    def test_attention_export(self):
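        # Export the ET attention with torch.export (dynamic seq_len) and compare
        # the exported module's output against the torchtune reference.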
        # Self attention.

        # test with kv cache
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        with torch.no_grad():
            et_mha_ep = torch.export.export(
                self.et_mha,
                (self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                dynamic_shapes=self.dynamic_shapes,
            )
        et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res, tt_res)

    @unittest.skipIf(
        int(os.getenv("RUN_SKIPPED", 0)) < 1, reason="TODO(T207740932): test is flaky"
    )
    def test_attention_aoti(self):
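        # Compile the ET attention with AOTInductor, package it as a .pt2 archive,
        # load it back, and compare against the torchtune reference in eager mode.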
        # Self attention.

        # test with kv cache
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        with torch.no_grad():
            so = torch._export.aot_compile(
                self.et_mha,
                args=(self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                options={
                    "aot_inductor.package": True,
                    "reorder_for_peak_memory": False,
                },
                dynamic_shapes=self.dynamic_shapes,
            )
        with tempfile.TemporaryDirectory() as tempdir:
            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
            mha_aoti = load_package(path)

            aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
            assert_close(aoti_res, tt_res)

    def test_attention_executorch(self):
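        # Lower the exported ET attention to an ExecuTorch program and run it
        # through the ExecuTorch runtime, comparing against torchtune eager.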
        # Self attention.
        # TODO: Fix kv cache
        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

        with torch.no_grad():
            et_mha_ep = torch.export.export(
                self.et_mha,
                (self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                dynamic_shapes=self.dynamic_shapes,
            )
        et_program = to_edge(
            et_mha_ep,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        et_res = method.execute((self.x, self.x, self.input_pos))
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res[0], tt_res)

    def test_attention_torch_cond_eager(self):
        # Unlike vanilla torchtune MHA, the ET module rewrites the `if` condition
        # with torch.cond, so both implementations must produce the same results on
        # either side of that condition.
        # For the first run we pass a real `y` (self.x); for the second run the ET
        # module receives a tensor full of nan while torchtune receives None.
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

        # Causal-mask rows for the current input positions.
        mask = self.causal_mask[self.input_pos, :]
        # First run
        et_res = self.et_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.

        self.assertTrue(torch.allclose(et_res, tt_res))

        # Second run: test KV-cache reads. Input positions are [10, 11, ..., 19].
        next_input_pos = torch.arange(10, 20).unsqueeze(0)

        empty_y = torch.full_like(self.x, torch.nan)
        mask = self.causal_mask[next_input_pos, :]
        et_res = self.et_mha(
            self.x, empty_y, mask=mask, input_pos=next_input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, None, mask=mask, input_pos=next_input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)
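

# Standard unittest entry point so the suite can also be run directly, e.g.
# `python test_attention.py` (assumes executorch and torchtune are installed).
if __name__ == "__main__":
    unittest.main()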
245