# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Source transformations for SDPA and the KV cache used when exporting Llama2 to flatbuffer

import math
from typing import Tuple, Union

import torch

from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    QuantizedKVCache,
)


class SDPACustom(torch.nn.Module):
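    """SDPA module that dispatches to the ExecuTorch custom llama ops
    (torch.ops.llama.custom_sdpa / torch.ops.llama.sdpa_with_kv_cache) instead of
    eager attention; inputs are converted to float32 for the op and back on return."""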
    def __init__(
        self,
        kv_cache: Union[KVCache, QuantizedKVCache],
        dim: int,
    ):
        super().__init__()
        # Custom op only supports float32 currently. Converting to/from float32 is
        # faster than not having the op.
        self.kv_cache = kv_cache
        if not isinstance(kv_cache, QuantizedKVCache):
            self.kv_cache = kv_cache.to(torch.float)
        else:
            assert (
                kv_cache.cache_fp_type == torch.float32
            ), "Only float32 is supported for custom SDPA"
        self.dim = dim

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        # Custom op only supports float32 currently. Converting to/from float32 is
        # faster than not having the op.
        input_dtype = q.dtype
        q = q.to(dtype=torch.float)
        k = k.to(dtype=torch.float)
        v = v.to(dtype=torch.float)

        k_cache = self.kv_cache.k_cache
        v_cache = self.kv_cache.v_cache
        if isinstance(self.kv_cache, QuantizedKVCache):
            # Update the quantized cache (data, scales, and zero points);
            # update() returns the dequantized k/v caches.
            # Not the most optimal path; optimizations to follow.
            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
            output = torch.ops.llama.custom_sdpa(
                q,
                k_cache,
                v_cache,
                input_pos[0].item(),
                None,  # Attention mask
                0,  # dropout probability. Ignored by the code
                True,  # is_causal
            )
        else:
            output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                k_cache,
                v_cache,
                input_pos[0].item(),
                seqlen,
                None,  # Attention mask
                0,  # dropout probability. Ignored by the code
                True,  # is_causal
            )
        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


def _replace_sdpa_with_custom_op(module: torch.nn.Module):
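    """Recursively replace SDPA child modules with SDPACustom, in place."""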
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(
                module,
                name,
                SDPACustom(child.kv_cache, child.dim),
            )
        else:
            _replace_sdpa_with_custom_op(child)


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
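    # Importing this module registers the llama custom SDPA ops with PyTorch;
    # the import is otherwise unused (hence the noqa).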
    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa

    _replace_sdpa_with_custom_op(module)
    return module


class SDPASimple(torch.nn.Module):
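    """Eager reference SDPA: softmax(Q @ K^T / sqrt(head_dim) + mask) @ V over the
    updated KV cache, with GQA heads expanded via repeat_interleave."""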

    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        head_dim: int,
        n_rep: int,
    ):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        k, v = self.kv_cache.update(input_pos, k, v)
        attn_mask = mask[None, None, input_pos]

        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        scale_factor = 1 / math.sqrt(q.size(-1))
        attn_weight = q @ k.transpose(-2, -1) * scale_factor
        attn_weight += attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1)
        y = attn_weight @ v

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim).
    """
    # TODO: Returning early when n_rep == 1 trips a bug in source partitioning;
    # investigate before re-enabling the shortcut below.
    # if n_rep == 1:
    #     return hidden_states

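    # Note: iterating over hidden_states[0] below assumes batch == 1; with a larger
    # batch the final reshape would fail, since only batch index 0 is gathered.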
    new_kv = []
    batch, n_heads, seqlen, head_dim = hidden_states.shape
    n_heads *= n_rep
    for h in hidden_states[0]:
        new_kv += [h] * n_rep
    return torch.cat(new_kv, 0).reshape(batch, n_heads, seqlen, head_dim)


class SDPAFlex(torch.nn.Module):
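    """Eager SDPA variant that expands K/V with the concat-based repeat_kv helper
    and applies mask[input_pos] as the attention mask."""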

    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        n_rep: int,
    ):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.n_rep = n_rep

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

        k, v = self.kv_cache.update(input_pos, k, v)
        k = repeat_kv(k, self.n_rep)
        v = repeat_kv(v, self.n_rep)
        attn_mask = mask[input_pos]

        scale_factor = 1 / math.sqrt(q.size(-1))
        attn_weight = q @ k.transpose(-2, -1) * scale_factor
        attn_weight += attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1)
        y = attn_weight @ v

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
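    """Recursively replace SDPA modules with SDPASimple, in place, and return the module."""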
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(
                module,
                name,
                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
            )
        else:
            replace_sdpa_with_simple_sdpa(child)
    return module


def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
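    """Recursively replace SDPA modules with SDPAFlex, in place, and return the module."""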
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(
                module,
                name,
                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
            )
        else:
            replace_sdpa_with_flex_sdpa(child)
    return module


@torch.library.custom_op("coreml::sdpa", mutates_args=())
def sdpa(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
    """Same as F.scaled_dot_product_attention, but wrapped in a custom op so it is not decomposed during dialect conversion."""
    return torch.ops.aten.scaled_dot_product_attention.default(
        q, k, v, attn_mask=attn_mask
    )


@torch.library.register_fake("coreml::sdpa")
def _(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
    expected_shape = list(q.shape)
    expected_shape[-1] = v.shape[-1]
    return q.new_empty(expected_shape)


class SDPACoreML(torch.nn.Module):
    """Similar to SDPASimple, but uses the coreml::sdpa custom op for the SDPA calculation."""

    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        head_dim: int,
        n_rep: int,
    ):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        k, v = self.kv_cache.update(input_pos, k, v)
        attn_mask = mask[None, None, input_pos]

        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        y = torch.ops.coreml.sdpa(q, k, v, attn_mask)

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
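    """Recursively replace SDPA modules with SDPACoreML, in place, and return the module."""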
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(
                module,
                name,
                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
            )
        else:
            replace_sdpa_with_coreml_sdpa(child)
    return module


class KVCacheCoreML(torch.nn.Module):
309    """
310    Rather than k_out[:, :, input_pos] = k_val, use torch.ops.aten.index_put_,
311    which can directly translate to CoreML iOS18.silce_update
312    """

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()
        self.max_seq_length = max_seq_length
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
        return k_out, v_out


def replace_kv_cache_with_coreml_kv_cache(module: torch.nn.Module):
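    """Recursively replace KVCache modules with KVCacheCoreML, in place, and return the module."""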
    for name, child in module.named_children():
        if isinstance(child, KVCache):
            setattr(
                module,
                name,
                KVCacheCoreML(
                    child.max_batch_size,
                    child.max_seq_length,
                    child.n_heads,
                    child.head_dim,
                    child.k_cache.dtype,
                ),
            )
        else:
            replace_kv_cache_with_coreml_kv_cache(child)
    return module


class KVCacheSimple(torch.nn.Module):
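    """KV cache stored in (batch, seqlen, n_heads, head_dim) layout; update() writes new
    entries with index_put_ and returns the caches transposed to (batch, n_heads, seqlen, head_dim)."""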
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()
        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer(
            "past_k_caches",
            torch.zeros(cache_shape, dtype=dtype, device="cpu"),
            persistent=False,
        )
        self.register_buffer(
            "past_v_caches",
            torch.zeros(cache_shape, dtype=dtype, device="cpu"),
            persistent=False,
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val)
        v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val)

        k_out = k_out.transpose(1, 2)
        v_out = v_out.transpose(1, 2)
        return k_out, v_out


def replace_kv_cache_with_simple_kv_cache(module: torch.nn.Module):
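    """Recursively replace KVCache modules with KVCacheSimple, in place, and return the module."""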
    for name, child in module.named_children():
        if isinstance(child, KVCache):
            setattr(
                module,
                name,
                KVCacheSimple(
                    child.max_batch_size,
                    child.max_seq_length,
                    child.n_heads,
                    child.head_dim,
                    child.k_cache.dtype,
                ),
            )
        else:
            replace_kv_cache_with_simple_kv_cache(child)
    return module


def replace_causal_mask(module: torch.nn.Module):
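    """Replace causal mask buffers (named "mask") with an additive mask:
    0 on and below the diagonal, -inf strictly above it."""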
    for buffer_fqn_name, buffer in module.named_buffers():
        buffer_name = buffer_fqn_name.split(".")[-1]
        if buffer_name == "mask":
            max_seq_len = buffer.shape[-1]
            mask = torch.full(
                (max_seq_len, max_seq_len),
                float("-inf"),
                device="cpu",
            )

            mask = torch.triu(mask, diagonal=1)
            module.register_buffer(buffer_name, mask)
    for _, child in module.named_children():
        replace_causal_mask(child)
    return module