xref: /aosp_15_r20/external/executorch/kernels/quantized/cpu/embeddingxb.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/kernel_includes.h>
10*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
11*523fa7a6SAndroid Build Coastguard Worker #include <cinttypes>
12*523fa7a6SAndroid Build Coastguard Worker #include <cmath>
13*523fa7a6SAndroid Build Coastguard Worker 
14*523fa7a6SAndroid Build Coastguard Worker namespace torch {
15*523fa7a6SAndroid Build Coastguard Worker namespace executor {
16*523fa7a6SAndroid Build Coastguard Worker namespace native {
17*523fa7a6SAndroid Build Coastguard Worker 
// Short local names for the exec_aten types used in the declarations below.
// These aliases are introduced at torch::executor::native scope.
18*523fa7a6SAndroid Build Coastguard Worker using Tensor = exec_aten::Tensor;
19*523fa7a6SAndroid Build Coastguard Worker using Scalar = exec_aten::Scalar;
20*523fa7a6SAndroid Build Coastguard Worker using ScalarType = exec_aten::ScalarType;
21*523fa7a6SAndroid Build Coastguard Worker 
// Declaration of the context-free overload of quantized_embedding_xbit_out.
// Only the signature is visible in this header; the definition lives elsewhere.
//
// NOTE(review): judging by the name and parameters, this presumably gathers
// rows of a packed, `weight_nbit`-bit quantized `weight` table selected by
// `indices`, dequantizes them using `weight_scales` (and
// `opt_weight_zero_points` when present), and writes the result into `out` --
// confirm against the implementation.
//
// Parameters (types as declared):
//   weight                 - const Tensor&
//   weight_scales          - const Tensor&
//   opt_weight_zero_points - optional Tensor; may be empty
//   weight_quant_min/max   - int64_t bounds (presumably the quantized value
//                            range -- TODO confirm)
//   indices                - const Tensor&
//   out                    - mutable Tensor& output argument
//   weight_nbit            - int (presumably bits per quantized weight
//                            element -- TODO confirm valid values)
// Returns a Tensor& (presumably `out` itself -- confirm in the definition).
22*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_out(
23*523fa7a6SAndroid Build Coastguard Worker     // TODO Evaluate whether this name is appropriate for an operator that takes
24*523fa7a6SAndroid Build Coastguard Worker     // non quant input and returns fp output
25*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
26*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
27*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
28*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_min,
29*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_max,
30*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
31*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
32*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit);
33*523fa7a6SAndroid Build Coastguard Worker 
// Overload of quantized_embedding_xbit_out that additionally takes a
// KernelRuntimeContext& as its first parameter. The remaining parameters have
// the same types and order as the context-free overload declared in this
// header (the top-level `const` on the int64_t parameters is dropped here,
// which does not change the function's signature).
//
// NOTE(review): how `context` is used (e.g. for error reporting) is not
// visible from this declaration -- confirm against the implementation.
// Returns a Tensor& (presumably `out` -- confirm in the definition).
34*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_out(
35*523fa7a6SAndroid Build Coastguard Worker     KernelRuntimeContext& context,
36*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
37*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
38*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
39*523fa7a6SAndroid Build Coastguard Worker     int64_t weight_quant_min,
40*523fa7a6SAndroid Build Coastguard Worker     int64_t weight_quant_max,
41*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
42*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
43*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit);
44*523fa7a6SAndroid Build Coastguard Worker 
// Declaration of the context-free overload of
// quantized_embedding_xbit_dtype_out. Its parameter list matches
// quantized_embedding_xbit_out with one extra argument, `out_dtype`, placed
// between `indices` and `out`.
//
// NOTE(review): `out_dtype` is an optional ScalarType; presumably it selects
// or validates the element type of `out`, and the behaviour when it is empty
// is not visible here -- confirm against the implementation.
//
// Parameters (types as declared):
//   weight                 - const Tensor&
//   weight_scales          - const Tensor&
//   opt_weight_zero_points - optional Tensor; may be empty
//   weight_quant_min/max   - int64_t bounds (presumably the quantized value
//                            range -- TODO confirm)
//   indices                - const Tensor&
//   out_dtype              - optional ScalarType, passed by value
//   out                    - mutable Tensor& output argument
//   weight_nbit            - int (presumably bits per quantized weight
//                            element -- TODO confirm valid values)
// Returns a Tensor& (presumably `out` itself -- confirm in the definition).
45*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_dtype_out(
46*523fa7a6SAndroid Build Coastguard Worker     // TODO Evaluate whether this name is appropriate for an operator that takes
47*523fa7a6SAndroid Build Coastguard Worker     // non quant input and returns fp output
48*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
49*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
50*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
51*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_min,
52*523fa7a6SAndroid Build Coastguard Worker     const int64_t weight_quant_max,
53*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
54*523fa7a6SAndroid Build Coastguard Worker     exec_aten::optional<ScalarType> out_dtype,
55*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
56*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit);
57*523fa7a6SAndroid Build Coastguard Worker 
// Overload of quantized_embedding_xbit_dtype_out that additionally takes a
// KernelRuntimeContext& as its first parameter. The remaining parameters have
// the same types and order as the context-free overload declared in this
// header (the top-level `const` on the int64_t parameters is dropped here,
// which does not change the function's signature).
//
// NOTE(review): how `context` is used, and how an empty `out_dtype` is
// handled, are not visible from this declaration -- confirm against the
// implementation.
// Returns a Tensor& (presumably `out` -- confirm in the definition).
58*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_dtype_out(
59*523fa7a6SAndroid Build Coastguard Worker     KernelRuntimeContext& context,
60*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight,
61*523fa7a6SAndroid Build Coastguard Worker     const Tensor& weight_scales,
62*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::optional<Tensor>& opt_weight_zero_points,
63*523fa7a6SAndroid Build Coastguard Worker     int64_t weight_quant_min,
64*523fa7a6SAndroid Build Coastguard Worker     int64_t weight_quant_max,
65*523fa7a6SAndroid Build Coastguard Worker     const Tensor& indices,
66*523fa7a6SAndroid Build Coastguard Worker     exec_aten::optional<ScalarType> out_dtype,
67*523fa7a6SAndroid Build Coastguard Worker     Tensor& out,
68*523fa7a6SAndroid Build Coastguard Worker     int weight_nbit);
69*523fa7a6SAndroid Build Coastguard Worker 
70*523fa7a6SAndroid Build Coastguard Worker } // namespace native
71*523fa7a6SAndroid Build Coastguard Worker } // namespace executor
72*523fa7a6SAndroid Build Coastguard Worker } // namespace torch
73