xref: /aosp_15_r20/external/executorch/kernels/quantized/cpu/op_embedding4b.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
#include <executorch/kernels/quantized/cpu/embeddingxb.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
19 using Tensor = exec_aten::Tensor;
20 using Scalar = exec_aten::Scalar;
21 using ScalarType = exec_aten::ScalarType;
22 
/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight is quantized per channel, with a scale and zero_point
 * for each embedding.
 *
 * Corresponds as the out variant to torch.ops.quantized.embedding_4bit
 *
 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
 * metadata that is passed around which can be useful for pattern matching. See
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
 * info.
 */
/**
 * Context-free out variant of the 4-bit quantized embedding lookup.
 *
 * Thin wrapper: delegates to the shared x-bit implementation in
 * embeddingxb.h, passing 4 as the bit width of the packed weights.
 *
 * @param weight 4-bit quantized, per-channel embedding table.
 * @param weight_scales Per-channel scales for dequantization.
 * @param opt_weight_zero_points Optional per-channel zero points.
 * @param weight_quant_min Metadata only; not used in computation (see
 *        file-level comment).
 * @param weight_quant_max Metadata only; not used in computation.
 * @param indices Rows of the table to gather.
 * @param out Destination tensor for the dequantized embeddings.
 * @return Reference to out.
 */
Tensor& quantized_embedding_4bit_out(
    // TODO Evaluate whether this name is appropriate for an operator that takes
    // non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  return quantized_embedding_xbit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out,
      4); // bit width of the packed weights
}
55 
quantized_embedding_4bit_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,Tensor & out)56 Tensor& quantized_embedding_4bit_out(
57     KernelRuntimeContext& context,
58     const Tensor& weight,
59     const Tensor& weight_scales,
60     const exec_aten::optional<Tensor>& opt_weight_zero_points,
61     int64_t weight_quant_min,
62     int64_t weight_quant_max,
63     const Tensor& indices,
64     Tensor& out) {
65   return quantized_embedding_xbit_out(
66       context,
67       weight,
68       weight_scales,
69       opt_weight_zero_points,
70       weight_quant_min,
71       weight_quant_max,
72       indices,
73       out,
74       4);
75 }
76 
/**
 * Context-free dtype out variant of the 4-bit quantized embedding lookup.
 *
 * Same as quantized_embedding_4bit_out but additionally accepts an optional
 * output dtype, forwarded to the shared x-bit implementation (bit width = 4).
 *
 * @param weight 4-bit quantized, per-channel embedding table.
 * @param weight_scales Per-channel scales for dequantization.
 * @param opt_weight_zero_points Optional per-channel zero points.
 * @param weight_quant_min Metadata only; not used in computation (see
 *        file-level comment).
 * @param weight_quant_max Metadata only; not used in computation.
 * @param indices Rows of the table to gather.
 * @param out_dtype Optional dtype for the output tensor.
 * @param out Destination tensor for the dequantized embeddings.
 * @return Reference to out.
 */
Tensor& quantized_embedding_4bit_dtype_out(
    // TODO Evaluate whether this name is appropriate for an operator that takes
    // non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  return quantized_embedding_xbit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      4); // bit width of the packed weights
}
99 
quantized_embedding_4bit_dtype_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out)100 Tensor& quantized_embedding_4bit_dtype_out(
101     KernelRuntimeContext& context,
102     const Tensor& weight,
103     const Tensor& weight_scales,
104     const exec_aten::optional<Tensor>& opt_weight_zero_points,
105     int64_t weight_quant_min,
106     int64_t weight_quant_max,
107     const Tensor& indices,
108     exec_aten::optional<ScalarType> out_dtype,
109     Tensor& out) {
110   return quantized_embedding_xbit_dtype_out(
111       context,
112       weight,
113       weight_scales,
114       opt_weight_zero_points,
115       weight_quant_min,
116       weight_quant_max,
117       indices,
118       out_dtype,
119       out,
120       4);
121 }
122 
123 } // namespace native
124 } // namespace executor
125 } // namespace torch
126