# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Example script for exporting Llama2 to flatbuffer

import math
from typing import Tuple, Union

import torch

from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    QuantizedKVCache,
)


class SDPACustom(torch.nn.Module):
    def __init__(
        self,
        kv_cache: Union[KVCache, QuantizedKVCache],
        dim: int,
    ):
        super().__init__()
        # Custom op only supports float32 currently. Converting to/from float32 is
        # faster than not having the op.
        self.kv_cache = kv_cache
        if not isinstance(kv_cache, QuantizedKVCache):
            self.kv_cache = kv_cache.to(torch.float)
        else:
            assert (
                kv_cache.cache_fp_type == torch.float32
            ), "Only float32 is supported for custom SDPA"
        self.dim = dim

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        # Custom op only supports float32 currently. Converting to/from float32 is
        # faster than not having the op.
        input_dtype = q.dtype
        q = q.to(dtype=torch.float)
        k = k.to(dtype=torch.float)
        v = v.to(dtype=torch.float)

        k_cache = self.kv_cache.k_cache
        v_cache = self.kv_cache.v_cache
        if isinstance(self.kv_cache, QuantizedKVCache):
            # Update the quantized cache (values, scales and zero points) and get back
            # a dequantized kv cache. Not the most optimal path; optimizations to follow.
            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
            output = torch.ops.llama.custom_sdpa(
                q,
                k_cache,
                v_cache,
                input_pos[0].item(),
                None,  # Attention mask
                0,  # dropout probability. Ignored by the code
                True,  # is_causal
            )
        else:
            output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                k_cache,
                v_cache,
                input_pos[0].item(),
                seqlen,
                None,  # Attention mask
                0,  # dropout probability. Ignored by the code
                True,  # is_causal
            )
        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


def _replace_sdpa_with_custom_op(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(
                module,
                name,
                SDPACustom(child.kv_cache, child.dim),
            )
        else:
            _replace_sdpa_with_custom_op(child)


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa

    _replace_sdpa_with_custom_op(module)
    return module
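

# Usage sketch (illustrative, not part of the original module): the rewrite is applied
# in place to an eager Llama model before export. `model` below is assumed to be a
# llama_transformer Transformer instance built with a KV cache:
#
#     model = replace_sdpa_with_custom_op(model)
#     # Every SDPA submodule is now an SDPACustom that dispatches to
#     # torch.ops.llama.custom_sdpa / torch.ops.llama.sdpa_with_kv_cache.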


class SDPASimple(torch.nn.Module):

    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        head_dim: int,
        n_rep: int,
    ):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        k, v = self.kv_cache.update(input_pos, k, v)
        attn_mask = mask[None, None, input_pos]

        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        scale_factor = 1 / math.sqrt(q.size(-1))
        attn_weight = q @ k.transpose(-2, -1) * scale_factor
        attn_weight += attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1)
        y = attn_weight @ v

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    # TODO: Encountered a bug with source partitioning; needs more investigation.
    # if n_rep == 1:
    #     return hidden_states

    new_kv = []
    batch, n_heads, seqlen, head_dim = hidden_states.shape
    n_heads *= n_rep
    for h in hidden_states[0]:
        new_kv += [h] * n_rep
    return torch.cat(new_kv, 0).reshape(batch, n_heads, seqlen, head_dim)
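

# Shape sketch (illustrative, assuming a batch size of 1 as used with the KV cache here):
# for k of shape (1, n_kv_heads, seqlen, head_dim), repeat_kv(k, n_rep) returns shape
# (1, n_kv_heads * n_rep, seqlen, head_dim) and matches k.repeat_interleave(n_rep, dim=1):
#
#     k = torch.randn(1, 2, 16, 64)
#     assert torch.equal(repeat_kv(k, 4), k.repeat_interleave(4, dim=1))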


class SDPAFlex(torch.nn.Module):

    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        n_rep: int,
    ):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.n_rep = n_rep

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

        k, v = self.kv_cache.update(input_pos, k, v)
        k = repeat_kv(k, self.n_rep)
        v = repeat_kv(v, self.n_rep)
        attn_mask = mask[input_pos]

        scale_factor = 1 / math.sqrt(q.size(-1))
        attn_weight = q @ k.transpose(-2, -1) * scale_factor
        attn_weight += attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1)
        y = attn_weight @ v

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(
                module,
                name,
                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
            )
        else:
            replace_sdpa_with_simple_sdpa(child)
    return module


def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(
                module,
                name,
                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
            )
        else:
            replace_sdpa_with_flex_sdpa(child)
    return module


@torch.library.custom_op("coreml::sdpa", mutates_args=())
def sdpa(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
    return torch.ops.aten.scaled_dot_product_attention.default(
        q, k, v, attn_mask=attn_mask
    )


@torch.library.register_fake("coreml::sdpa")
def _(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
    expected_shape = list(q.shape)
    expected_shape[-1] = v.shape[-1]
    return q.new_empty(expected_shape)
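

# Eager-mode sanity check for the op registered above (illustrative sketch, not part of
# the original module): with an explicit mask, coreml::sdpa should match
# torch.nn.functional.scaled_dot_product_attention:
#
#     q = torch.randn(1, 8, 4, 64)
#     k = torch.randn(1, 8, 16, 64)
#     v = torch.randn(1, 8, 16, 64)
#     attn_mask = torch.zeros(4, 16)
#     out = torch.ops.coreml.sdpa(q, k, v, attn_mask)
#     ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
#     torch.testing.assert_close(out, ref)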


class SDPACoreML(torch.nn.Module):
    """Similar to SDPASimple, but uses the coreml::sdpa custom op to do the SDPA calculation."""

    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        head_dim: int,
        n_rep: int,
    ):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        k, v = self.kv_cache.update(input_pos, k, v)
        attn_mask = mask[None, None, input_pos]

        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        y = torch.ops.coreml.sdpa(q, k, v, attn_mask)

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(
                module,
                name,
                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
            )
        else:
            replace_sdpa_with_coreml_sdpa(child)
    return module


class KVCacheCoreML(torch.nn.Module):
    """
    Rather than k_out[:, :, input_pos] = k_val, use torch.ops.aten.index_put_,
    which can directly translate to CoreML iOS18.slice_update.
    """

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()
        self.max_seq_length = max_seq_length
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
        return k_out, v_out


def replace_kv_cache_with_coreml_kv_cache(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, KVCache):
            setattr(
                module,
                name,
                KVCacheCoreML(
                    child.max_batch_size,
                    child.max_seq_length,
                    child.n_heads,
                    child.head_dim,
                    child.k_cache.dtype,
                ),
            )
        else:
            replace_kv_cache_with_coreml_kv_cache(child)
    return module


class KVCacheSimple(torch.nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()
        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer(
            "past_k_caches",
            torch.zeros(cache_shape, dtype=dtype, device="cpu"),
            persistent=False,
        )
        self.register_buffer(
            "past_v_caches",
            torch.zeros(cache_shape, dtype=dtype, device="cpu"),
            persistent=False,
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val)
        v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val)

        k_out = k_out.transpose(1, 2)
        v_out = v_out.transpose(1, 2)
        return k_out, v_out


def replace_kv_cache_with_simple_kv_cache(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, KVCache):
            setattr(
                module,
                name,
                KVCacheSimple(
                    child.max_batch_size,
                    child.max_seq_length,
                    child.n_heads,
                    child.head_dim,
                    child.k_cache.dtype,
                ),
            )
        else:
            replace_kv_cache_with_simple_kv_cache(child)
    return module
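

# Usage sketch (illustrative, not part of the original module): KVCacheSimple stores the
# cache as (batch, max_seq_length, n_heads, head_dim) and returns head-major views from
# update(), which lines up with SDPAFlex passing k/v to update() without transposing.
# One possible combination, with `model` assumed to be an eager Llama Transformer:
#
#     model = replace_kv_cache_with_simple_kv_cache(model)
#     model = replace_sdpa_with_flex_sdpa(model)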


def replace_causal_mask(module: torch.nn.Module):
    for buffer_fqn_name, buffer in module.named_buffers():
        buffer_name = buffer_fqn_name.split(".")[-1]
        if buffer_name == "mask":
            max_seq_len = buffer.shape[-1]
            mask = torch.full(
                (max_seq_len, max_seq_len),
                float("-inf"),
                device="cpu",
            )

            mask = torch.triu(mask, diagonal=1)
            module.register_buffer(buffer_name, mask)
    for _, child in module.named_children():
        replace_causal_mask(child)
    return module
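

# Usage sketch (illustrative, not part of the original module): one way the
# CoreML-oriented rewrites above can be chained on an eager Llama Transformer `model`:
#
#     model = replace_kv_cache_with_coreml_kv_cache(model)
#     model = replace_sdpa_with_coreml_sdpa(model)
#     model = replace_causal_mask(model)
#     # replace_causal_mask re-registers each `mask` buffer as an additive
#     # upper-triangular -inf mask of shape (max_seq_len, max_seq_len).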