1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/quantized/cpu/embeddingxb.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <algorithm>
12 #include <cinttypes>
13 #include <cmath>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
19 using Tensor = exec_aten::Tensor;
20 using Scalar = exec_aten::Scalar;
21 using ScalarType = exec_aten::ScalarType;
22
23 /**
24 * Retrieves the embeddings specified by indices, dequantizes them, and stores
25 * them in out. The weight is quantized per channel, with a scale and zero_point
26 * for each embedding.
27 *
28 * This is the out variant of torch.ops.quantized.embedding_4bit.
29 *
30 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
31 * metadata that is passed around which can be useful for pattern matching. See
32 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
33 * info.
34 */
// Dequantizes the 4-bit-quantized rows of `weight` selected by `indices`
// into `out`. quant_min/quant_max are carried as metadata only.
// TODO Evaluate whether this name is appropriate for an operator that takes
// non quant input and returns fp output
Tensor& quantized_embedding_4bit_out(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  // Delegate to the shared x-bit implementation, fixing the weight bit width.
  constexpr int kBitWidth = 4;
  return quantized_embedding_xbit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out,
      kBitWidth);
}
55
quantized_embedding_4bit_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,Tensor & out)56 Tensor& quantized_embedding_4bit_out(
57 KernelRuntimeContext& context,
58 const Tensor& weight,
59 const Tensor& weight_scales,
60 const exec_aten::optional<Tensor>& opt_weight_zero_points,
61 int64_t weight_quant_min,
62 int64_t weight_quant_max,
63 const Tensor& indices,
64 Tensor& out) {
65 return quantized_embedding_xbit_out(
66 context,
67 weight,
68 weight_scales,
69 opt_weight_zero_points,
70 weight_quant_min,
71 weight_quant_max,
72 indices,
73 out,
74 4);
75 }
76
// dtype variant: like quantized_embedding_4bit_out but also accepts an
// optional output ScalarType, which is forwarded to the x-bit kernel.
// TODO Evaluate whether this name is appropriate for an operator that takes
// non quant input and returns fp output
Tensor& quantized_embedding_4bit_dtype_out(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // All real work happens in the shared implementation; this wrapper only
  // pins the bit width to 4.
  constexpr int kBitWidth = 4;
  return quantized_embedding_xbit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      kBitWidth);
}
99
quantized_embedding_4bit_dtype_out(KernelRuntimeContext & context,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,int64_t weight_quant_min,int64_t weight_quant_max,const Tensor & indices,exec_aten::optional<ScalarType> out_dtype,Tensor & out)100 Tensor& quantized_embedding_4bit_dtype_out(
101 KernelRuntimeContext& context,
102 const Tensor& weight,
103 const Tensor& weight_scales,
104 const exec_aten::optional<Tensor>& opt_weight_zero_points,
105 int64_t weight_quant_min,
106 int64_t weight_quant_max,
107 const Tensor& indices,
108 exec_aten::optional<ScalarType> out_dtype,
109 Tensor& out) {
110 return quantized_embedding_xbit_dtype_out(
111 context,
112 weight,
113 weight_scales,
114 opt_weight_zero_points,
115 weight_quant_min,
116 weight_quant_max,
117 indices,
118 out_dtype,
119 out,
120 4);
121 }
122
123 } // namespace native
124 } // namespace executor
125 } // namespace torch
126