1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>

#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
// @lint-ignore CLANGTIDY facebook-unused-include-check
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
16
17 namespace torch {
18 namespace executor {
19
20 namespace native {
21
22 namespace {
validate_cache_params(const Tensor & quantized_value,const Tensor & quantized_cache,int64_t start_pos,int64_t seq_length)23 bool validate_cache_params(
24 const Tensor& quantized_value,
25 const Tensor& quantized_cache,
26 int64_t start_pos,
27 int64_t seq_length) {
28 ET_LOG_MSG_AND_RETURN_IF_FALSE(
29 quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
30
31 ET_LOG_MSG_AND_RETURN_IF_FALSE(
32 quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
33
34 ET_LOG_MSG_AND_RETURN_IF_FALSE(
35 start_pos < quantized_cache.size(1),
36 "start_pos must be less than cache size at dim 1");
37
38 ET_LOG_MSG_AND_RETURN_IF_FALSE(
39 (start_pos + seq_length) <= quantized_cache.size(1),
40 "start_post + seq_length must be less than max seq length supported by cache."
41 "start pos: %" PRId64 ", seq_length: %" PRId64
42 "."
43 "cache size: %zd",
44 start_pos,
45 seq_length,
46 quantized_cache.size(1));
47
48 // Make sure they are in contiguous dim order
49 ET_LOG_MSG_AND_RETURN_IF_FALSE(
50 is_contiguous_dim_order(
51 quantized_cache.dim_order().data(), quantized_cache.dim()),
52 "quantized cache must be in contiguous dim order");
53
54 ET_LOG_MSG_AND_RETURN_IF_FALSE(
55 is_contiguous_dim_order(
56 quantized_value.dim_order().data(), quantized_value.dim()),
57 "quantized value must be in contiguous dim order");
58
59 return true;
60 }
61 } // anonymous namespace
62
update_quantized_cache_out(RuntimeContext & ctx,const Tensor & value,Tensor & cache,const int64_t start_pos,Tensor & output)63 Tensor& update_quantized_cache_out(
64 RuntimeContext& ctx,
65 const Tensor& value,
66 Tensor& cache,
67 const int64_t start_pos,
68 Tensor& output) {
69 (void)ctx;
70 int64_t seq_len = value.size(1);
71 ET_KERNEL_CHECK(
72 ctx,
73 validate_cache_params(value, cache, start_pos, seq_len),
74 InvalidArgument,
75 output);
76
77 ET_CHECK_MSG(
78 value.size(0) == cache.size(0),
79 "projected_value batch size should be equal to the cache batch size.");
80 ET_CHECK_MSG(
81 value.size(2) == cache.size(2),
82 "projected_value number of heads should be equal to the cache number of heads.");
83 ET_CHECK_MSG(
84 value.size(3) == cache.size(3),
85 "projected_value embedding dimension should be equal to the cache embedding dimension.");
86 ET_CHECK_MSG(
87 value.element_size() == cache.element_size(),
88 "projected_value data type size should be equal to the cache data type size.");
89
90 ET_CHECK_MSG(
91 is_contiguous_dim_order(value.dim_order().data(), value.dim()),
92 "projected value must be in contiguous dim order");
93 ET_CHECK_MSG(
94 is_contiguous_dim_order(cache.dim_order().data(), cache.dim()),
95 "projected value must be in contiguous dim order");
96
97 const void* value_data = value.const_data_ptr();
98 void* cache_data = cache.mutable_data_ptr();
99
100 ET_CHECK_MSG(value_data, "projected_value data is null");
101 ET_CHECK_MSG(cache_data, "cache data is null");
102
103 auto cache_strides = cache.strides();
104 exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
105 exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
106
107 auto value_strides = value.strides();
108 exec_aten::StridesType value_batch_dim_stride = value_strides[0];
109
110 exec_aten::SizesType num_bytes_to_copy =
111 (value.numel() / value.size(0)) * value.element_size();
112
113 for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) {
114 exec_aten::SizesType cache_pos_offset =
115 (batch_line * cache_batch_dim_stride +
116 start_pos * cache_seq_dim_stride) *
117 cache.element_size();
118 exec_aten::SizesType value_pos_offset =
119 (batch_line * value_batch_dim_stride) * cache.element_size();
120
121 std::memcpy(
122 (uint8_t*)cache_data + cache_pos_offset,
123 (uint8_t*)value_data + value_pos_offset,
124 num_bytes_to_copy);
125 }
126
127 // Noone uses output. Just a placeholder.
128 return output;
129 }
130 } // namespace native
131 } // namespace executor
132 } // namespace torch
133
// Registers update_quantized_cache_out as the "update_quantized_cache.out"
// op in the "llama" namespace.
//
// Despite its name, this is a generic in-place tensor update: the kernel
// memcpy's `value` into `cache` along dim 1 (the sequence dimension),
// starting at `start_pos`. It assumes both tensors are 4D and in contiguous
// dim order (memory layout). Only element sizes are compared, so nothing in
// the kernel is specific to quantized data.
// TODO: rename this op to update_cache.
EXECUTORCH_LIBRARY(
    llama,
    "update_quantized_cache.out",
    torch::executor::native::update_quantized_cache_out);
144