# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Import the custom ops defined in op_sdpa_aot.cpp. Those ops are registered
# through the PyTorch C++ APIs, so here we need to load the shared library.
# This is only needed for OSS.

# pyre-unsafe

import logging
from pathlib import Path

import torch

from torch.library import impl

# TODO rename this file to custom_ops_meta_registration.py
try:
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
    op2 = torch.ops.llama.fast_hadamard_transform.default
    assert op2 is not None
except Exception:
    libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
    logging.info(f"Loading custom ops library: {libs[0]}")
    torch.ops.load_library(libs[0])
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
    op2 = torch.ops.llama.fast_hadamard_transform.default
    assert op2 is not None

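# The Library object below uses the "IMPL" kind: it attaches implementations to
# operators whose schemas already exist in the "llama" namespace (defined by the
# C++ library loaded above). The "Meta" kernels registered in this file only
# propagate shapes and dtypes so the ops can be traced and exported with fake or
# meta tensors; they do not compute real attention or cache updates.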
custom_ops_lib = torch.library.Library("llama", "IMPL")


def _validate_params(
    query,
    key,
    value,
    key_cache,
    value_cache,
    start_pos,
    seq_len,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
):
    assert (
        query.dim() == 4
    ), f"Expected query to be 4 dimensional but got {query.dim()} dimensions."
    assert (
        key.dim() == 4
    ), f"Expected key to be 4 dimensional but got {key.dim()} dimensions."
    assert (
        value.dim() == 4
    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."

    assert (
        query.dtype == torch.float32
    ), f"Expected query to be float32 but got {query.dtype}"
    assert key.dtype == torch.float32, f"Expected key to be float32 but got {key.dtype}"
    assert (
        value.dtype == torch.float32
    ), f"Expected value to be float32 but got {value.dtype}"

    assert (
        key_cache.dim() == 4
    ), f"Expected key_cache to be 4 dimensional but got {key_cache.dim()}"
    assert (
        value_cache.dim() == 4
    ), f"Expected value_cache to be 4 dimensional but got {value_cache.dim()}"

    assert (
        key_cache.dtype == torch.float32
    ), f"Expected key_cache to be float32 but got {key_cache.dtype}"
    assert (
        value_cache.dtype == torch.float32
    ), f"Expected value_cache to be float32 but got {value_cache.dtype}"

    assert (
        key_cache.size() == value_cache.size()
    ), f"Key cache and value cache must have same size but got {key_cache.size()} and {value_cache.size()}"

    # These asserts are valid, but enforcing them would require adding
    # constrain_as_size/value calls to the model, which we don't want to do right now.
    # assert start_pos < key_cache.size(
    #     1
    # ), f"Start position {start_pos} must be less than sequence length {key_cache.size(1)}"
    # assert (start_pos + seq_len) < key_cache.size(
    #     1
    # ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {key_cache.size(1)}"

    if attn_mask is not None:
        assert (
            attn_mask.dim() == 2
        ), f"Expected attn_mask to be 2 dimensional but got {attn_mask.dim()} dimensions."
        assert (attn_mask.dtype == torch.float32) or (
            attn_mask.dtype == torch.float16
        ), f"Expected attn_mask to be float32 or float16 but got {attn_mask.dtype}"


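# Illustrative shapes that satisfy the checks above. The [batch, seq, n_heads,
# head_dim] layout is implied by the seq_len/start_pos handling elsewhere in this
# file; the exact sizes are assumptions, not requirements:
#   query/key/value:        [bsz, seq_len, n_heads, head_dim], float32
#   key_cache/value_cache:  [bsz, max_seq_len, n_heads, head_dim], float32, equal sizes
#   attn_mask (optional):   2-D, float32 or float16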
@impl(custom_ops_lib, "sdpa_with_kv_cache", "Meta")
def sdpa_with_kv_cache_meta(
    query,
    key,
    value,
    key_cache,
    value_cache,
    start_pos,
    seq_len,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    _validate_params(
        query,
        key,
        value,
        key_cache,
        value_cache,
        start_pos,
        seq_len,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
    )

    return torch.empty_like(query)


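# Illustrative call through the dispatcher (the argument order mirrors the meta
# signature above; tensor shapes and values are assumptions):
#   out = torch.ops.llama.sdpa_with_kv_cache(
#       query, key, value, key_cache, value_cache, start_pos, seq_len
#   )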
@impl(custom_ops_lib, "fast_hadamard_transform", "Meta")
def fast_hadamard_transform_meta(mat):
    # assert mat.stride(-1) == 1, "input matrix must be contiguous in the last dimension!"
    # assert mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!"
    # assert mat.is_contiguous(), "input matrix must be contiguous currently!"
    return torch.empty_like(mat)


@impl(custom_ops_lib, "custom_sdpa", "Meta")
def custom_sdpa(
    query,
    key_cache,
    value_cache,
    start_pos,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    seq_len = query.size(1)
    _validate_params(
        query,
        key_cache,
        value_cache,
        key_cache,
        value_cache,
        start_pos,
        seq_len,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
    )

    return torch.empty_like(query)


def _validate_update_cache_params(
    value,
    cache,
    start_pos,
):
    seq_len = value.size(1)
    assert (
        value.dim() == 4
    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."

    assert (
        value.dtype == cache.dtype
    ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"

    for i in [0, 2, 3]:
        assert value.size(i) == cache.size(
            i
        ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"

    # torch._check_is_size/_check communicate the bounds on start_pos to the
    # export tracer (so a symbolic start_pos does not trip data-dependent
    # guards), while the asserts give a readable error in eager mode. The bound
    # is the cache's sequence-length dimension, since there is no way to plumb a
    # tighter limit from the model config.
    torch._check_is_size(start_pos)
    torch._check(start_pos < cache.size(1))
    assert start_pos < cache.size(
        1
    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"

    torch._check((start_pos + seq_len) < cache.size(1))
    assert (start_pos + seq_len) < cache.size(
        1
    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"


@impl(custom_ops_lib, "update_quantized_cache", "Meta")
def update_quantized_cache_meta(
    value,
    cache,
    start_pos,
):
    _validate_update_cache_params(
        value,
        cache,
        start_pos,
    )

    # update_quantized_cache doesn't really return anything, but I don't know a
    # better workaround. Should we just return cache instead? I am afraid that
    # would result in an extra memory allocation.
    return torch.empty((1,), dtype=value.dtype, device="meta")
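

if __name__ == "__main__":
    # Minimal smoke test for the meta kernels above. The shapes below are
    # assumptions chosen only to satisfy the validation helpers in this file;
    # real models may use different sizes.
    bsz, n_heads, head_dim, max_seq_len = 1, 8, 64, 128

    q = torch.randn(bsz, 1, n_heads, head_dim, device="meta")
    k = torch.randn(bsz, 1, n_heads, head_dim, device="meta")
    v = torch.randn(bsz, 1, n_heads, head_dim, device="meta")
    k_cache = torch.randn(bsz, max_seq_len, n_heads, head_dim, device="meta")
    v_cache = torch.randn(bsz, max_seq_len, n_heads, head_dim, device="meta")

    out = sdpa_with_kv_cache_meta(q, k, v, k_cache, v_cache, 0, 1)
    assert out.shape == q.shape and out.device.type == "meta"

    out = custom_sdpa(q, k_cache, v_cache, 0)
    assert out.shape == q.shape

    mat = torch.randn(bsz, 128, device="meta")
    assert fast_hadamard_transform_meta(mat).shape == mat.shape

    ret = update_quantized_cache_meta(v, v_cache, 0)
    assert ret.shape == (1,)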