/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>

#include <cinttypes> // for PRId64, used in the log format strings below
#include <cstring> // for std::memcpy

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
// @lint-ignore CLANGTIDY facebook-unused-include-check
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

namespace torch {
namespace executor {

namespace native {

namespace {
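// Validates the rank, sequence-position bounds, and memory layout of the
// value/cache pair before any bytes are copied. Returns false (after logging)
// on the first violated invariant.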
bool validate_cache_params(
    const Tensor& quantized_value,
    const Tensor& quantized_cache,
    int64_t start_pos,
    int64_t seq_length) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      quantized_value.dim() == 4, "quantized_value must be a 4D tensor");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      start_pos < quantized_cache.size(1),
      "start_pos must be less than cache size at dim 1");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (start_pos + seq_length) <= quantized_cache.size(1),
      "start_pos + seq_length must not exceed the max seq length supported by the cache. "
      "start_pos: %" PRId64 ", seq_length: %" PRId64 ", cache size: %zd",
      start_pos,
      seq_length,
      quantized_cache.size(1));

  // Make sure they are in contiguous dim order
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      is_contiguous_dim_order(
          quantized_cache.dim_order().data(), quantized_cache.dim()),
      "quantized cache must be in contiguous dim order");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      is_contiguous_dim_order(
          quantized_value.dim_order().data(), quantized_value.dim()),
      "quantized value must be in contiguous dim order");

  return true;
}
} // anonymous namespace

Tensor& update_quantized_cache_out(
    RuntimeContext& ctx,
    const Tensor& value,
    Tensor& cache,
    const int64_t start_pos,
    Tensor& output) {
  int64_t seq_len = value.size(1);
  ET_KERNEL_CHECK(
      ctx,
      validate_cache_params(value, cache, start_pos, seq_len),
      InvalidArgument,
      output);

  ET_CHECK_MSG(
      value.size(0) == cache.size(0),
      "projected_value batch size should be equal to the cache batch size.");
  ET_CHECK_MSG(
      value.size(2) == cache.size(2),
      "projected_value number of heads should be equal to the cache number of heads.");
  ET_CHECK_MSG(
      value.size(3) == cache.size(3),
      "projected_value embedding dimension should be equal to the cache embedding dimension.");
  ET_CHECK_MSG(
      value.element_size() == cache.element_size(),
      "projected_value data type size should be equal to the cache data type size.");

  ET_CHECK_MSG(
      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
      "projected value must be in contiguous dim order");
  ET_CHECK_MSG(
      is_contiguous_dim_order(cache.dim_order().data(), cache.dim()),
      "cache must be in contiguous dim order");

  const void* value_data = value.const_data_ptr();
  void* cache_data = cache.mutable_data_ptr();

  ET_CHECK_MSG(value_data, "projected_value data is null");
  ET_CHECK_MSG(cache_data, "cache data is null");

  auto cache_strides = cache.strides();
  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];

  auto value_strides = value.strides();
  exec_aten::StridesType value_batch_dim_stride = value_strides[0];

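  // Both tensors are in contiguous dim order, so for each batch entry the
  // [seq_len, n_heads, head_dim] block is one contiguous slab of bytes; its
  // size is the per-batch element count times the element size.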
  exec_aten::SizesType num_bytes_to_copy =
      (value.numel() / value.size(0)) * value.element_size();

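  // Copy one slab per batch entry. The destination byte offset advances by
  // start_pos rows along the sequence dimension (dim 1) of the cache, while
  // the source offset only advances along the batch dimension of the value.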
  for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) {
    exec_aten::SizesType cache_pos_offset =
        (batch_line * cache_batch_dim_stride +
         start_pos * cache_seq_dim_stride) *
        cache.element_size();
    exec_aten::SizesType value_pos_offset =
        (batch_line * value_batch_dim_stride) * cache.element_size();

    std::memcpy(
        (uint8_t*)cache_data + cache_pos_offset,
        (uint8_t*)value_data + value_pos_offset,
        num_bytes_to_copy);
  }

  // No one uses the output; it is just a placeholder.
  return output;
}
} // namespace native
} // namespace executor
} // namespace torch

// Really this is just an in-place tensor update op that makes assumptions
// about the rank of the tensor and its dim order (memory layout). It further
// assumes that the indexing is along the sequence dimension (dim 1) of the
// tensor. A later diff will rename this to update_cache.
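//
// For intuition only: the per-batch memcpy loop above is semantically
// equivalent to the following PyTorch-style slice assignment (illustrative
// pseudocode, not part of this op's API):
//
//   cache[:, start_pos : start_pos + seq_len, :, :] = value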
139 // In later diffs will rename this to update_cache.
140 EXECUTORCH_LIBRARY(
141     llama,
142     "update_quantized_cache.out",
143     torch::executor::native::update_quantized_cache_out);
144
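// With the registration above, the kernel resolves as the out-variant
// llama::update_quantized_cache.out inside an ExecuTorch program. Assuming the
// companion eager-mode binding in this directory is also built (an assumption
// about the surrounding build, not something this file guarantees), it could
// be exercised from Python as, e.g.:
//
//   torch.ops.llama.update_quantized_cache(value, cache, start_pos)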