# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.extension.llm.modules.attention import (
    MultiHeadAttention as ETMultiHeadAttention,
)
from executorch.runtime import Runtime
from torch._inductor.package import load_package, package_aoti
from torch.testing import assert_close
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention


class AttentionTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(0)
        # Constants
        self.embed_dim = 2048
        self.num_heads = 8
        self.num_kv_heads = 8
        self.head_dim = 64
        self.max_seq_len = 128
        self.rope_base = 500_000
        self.scale_factor = 32

        # Module dependency injections.
        self.q_proj = torch.nn.Linear(
            self.embed_dim, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = torch.nn.Linear(
            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.k_proj.weight.requires_grad = False
        self.v_proj = torch.nn.Linear(
            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.v_proj.weight.requires_grad = False
        self.output_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.embed_dim, bias=False
        )
        self.pos_embeddings = Llama3ScaledRoPE(
            dim=self.head_dim,
            max_seq_len=self.max_seq_len,
            base=self.rope_base,
            scale_factor=self.scale_factor,
        )

        # Original TorchTune reference module to test accuracy against.
        self.tt_mha = TTMultiHeadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            q_proj=self.q_proj,
            k_proj=self.k_proj,
            v_proj=self.v_proj,
            output_proj=self.output_proj,
            pos_embeddings=self.pos_embeddings,
            max_seq_len=self.max_seq_len,
        )

        # Source transformed module that we are testing.
        self.et_mha = ETMultiHeadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            q_proj=self.q_proj,
            k_proj=self.k_proj,
            v_proj=self.v_proj,
            output_proj=self.output_proj,
            pos_embeddings=self.pos_embeddings,
            max_seq_len=self.max_seq_len,
        )
        self.et_mha.load_state_dict(self.tt_mha.state_dict())

        # Common inputs.
        seq_len = 10
        self.x = torch.randn(1, seq_len, self.embed_dim)
        self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
        self.dynamic_shapes = (
            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
        )
        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )

    def test_attention_eager(self):
        et_res = self.et_mha(self.x, self.x)  # Self attention.
        tt_res = self.tt_mha(self.x, self.x)  # Self attention.

        assert_close(et_res, tt_res)

        # test with kv cache
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)

        et_res = self.et_mha(self.x, self.x)  # Self attention.
        tt_res = self.tt_mha(self.x, self.x)  # Self attention.

        self.assertTrue(torch.allclose(et_res, tt_res))
        self.et_mha.reset_cache()
        self.tt_mha.reset_cache()

        et_res = self.et_mha(
            self.x, self.x, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, input_pos=self.input_pos
        )  # Self attention with input pos.

        self.assertTrue(torch.allclose(et_res, tt_res))

        # test kv cache read. Input pos can be [10, 11, ..., 19]
        next_input_pos = torch.arange(10, 20).unsqueeze(0)
        et_res = self.et_mha(
            self.x, self.x, input_pos=next_input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, input_pos=next_input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)

    def test_attention_export(self):
        # Self attention.

        # test with kv cache
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        with torch.no_grad():
            et_mha_ep = torch.export.export(
                self.et_mha,
                (self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                dynamic_shapes=self.dynamic_shapes,
            )
        et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res, tt_res)

    @unittest.skipIf(
        int(os.getenv("RUN_SKIPPED", 0)) < 1, reason="TODO(T207740932): test is flaky"
    )
    def test_attention_aoti(self):
        # Self attention.

        # test with kv cache
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        with torch.no_grad():
            so = torch._export.aot_compile(
                self.et_mha,
                args=(self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                options={
                    "aot_inductor.package": True,
                    "reorder_for_peak_memory": False,
                },
                dynamic_shapes=self.dynamic_shapes,
            )
        with tempfile.TemporaryDirectory() as tempdir:
            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
            mha_aoti = load_package(path)

            aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
            assert_close(aoti_res, tt_res)

    def test_attention_executorch(self):
        # Self attention.
        # TODO: Fix kv cache
        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

        with torch.no_grad():
            et_mha_ep = torch.export.export(
                self.et_mha,
                (self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                dynamic_shapes=self.dynamic_shapes,
            )
        et_program = to_edge(
            et_mha_ep,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        et_res = method.execute((self.x, self.x, self.input_pos))
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res[0], tt_res)

    def test_attention_torch_cond_eager(self):
        # Unlike vanilla torchtune MHA, the ET module rewrites the `if` condition
        # with torch.cond, so we need to make sure both branches of that condition
        # give the same results as the reference.
        # For the first run of MHA we provide `y` (self.x), but for the second run
        # `y` is a tensor full of NaN.
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

        # mask
        mask = self.causal_mask[self.input_pos, :]
        # First run
        et_res = self.et_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.

        self.assertTrue(torch.allclose(et_res, tt_res))

        # Second run tests the kv cache read. Input pos is [10, 11, ..., 19].
        next_input_pos = torch.arange(10, 20).unsqueeze(0)

        empty_y = torch.full_like(self.x, torch.nan)
        mask = self.causal_mask[next_input_pos, :]
        et_res = self.et_mha(
            self.x, empty_y, mask=mask, input_pos=next_input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, None, mask=mask, input_pos=next_input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)
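

# Optional entry point so the tests can also be run directly with
# `python <path-to-this-file>`, in addition to the usual test-runner discovery.
if __name__ == "__main__":
    unittest.main()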