1*523fa7a6SAndroid Build Coastguard Worker /* 2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates. 3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved. 4*523fa7a6SAndroid Build Coastguard Worker * 5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the 6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree. 7*523fa7a6SAndroid Build Coastguard Worker */ 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/kernel_includes.h> 10*523fa7a6SAndroid Build Coastguard Worker #include <algorithm> 11*523fa7a6SAndroid Build Coastguard Worker #include <cinttypes> 12*523fa7a6SAndroid Build Coastguard Worker #include <cmath> 13*523fa7a6SAndroid Build Coastguard Worker 14*523fa7a6SAndroid Build Coastguard Worker namespace torch { 15*523fa7a6SAndroid Build Coastguard Worker namespace executor { 16*523fa7a6SAndroid Build Coastguard Worker namespace native { 17*523fa7a6SAndroid Build Coastguard Worker 18*523fa7a6SAndroid Build Coastguard Worker using Tensor = exec_aten::Tensor; 19*523fa7a6SAndroid Build Coastguard Worker using Scalar = exec_aten::Scalar; 20*523fa7a6SAndroid Build Coastguard Worker using ScalarType = exec_aten::ScalarType; 21*523fa7a6SAndroid Build Coastguard Worker 22*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_out( 23*523fa7a6SAndroid Build Coastguard Worker // TODO Evaluate whether this name is appropriate for an operator that takes 24*523fa7a6SAndroid Build Coastguard Worker // non quant input and returns fp output 25*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight, 26*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales, 27*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points, 28*523fa7a6SAndroid Build Coastguard Worker const int64_t weight_quant_min, 29*523fa7a6SAndroid Build Coastguard Worker const int64_t weight_quant_max, 30*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices, 31*523fa7a6SAndroid Build Coastguard Worker Tensor& out, 32*523fa7a6SAndroid Build Coastguard Worker int weight_nbit); 33*523fa7a6SAndroid Build Coastguard Worker 34*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_out( 35*523fa7a6SAndroid Build Coastguard Worker KernelRuntimeContext& context, 36*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight, 37*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales, 38*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points, 39*523fa7a6SAndroid Build Coastguard Worker int64_t weight_quant_min, 40*523fa7a6SAndroid Build Coastguard Worker int64_t weight_quant_max, 41*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices, 42*523fa7a6SAndroid Build Coastguard Worker Tensor& out, 43*523fa7a6SAndroid Build Coastguard Worker int weight_nbit); 44*523fa7a6SAndroid Build Coastguard Worker 45*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_dtype_out( 46*523fa7a6SAndroid Build Coastguard Worker // TODO Evaluate whether this name is appropriate for an operator that takes 47*523fa7a6SAndroid Build Coastguard Worker // non quant input and returns fp output 48*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight, 49*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales, 50*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points, 51*523fa7a6SAndroid Build Coastguard Worker const int64_t weight_quant_min, 52*523fa7a6SAndroid Build Coastguard Worker const int64_t weight_quant_max, 53*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices, 54*523fa7a6SAndroid Build Coastguard Worker exec_aten::optional<ScalarType> out_dtype, 55*523fa7a6SAndroid Build Coastguard Worker Tensor& out, 56*523fa7a6SAndroid Build Coastguard Worker int weight_nbit); 57*523fa7a6SAndroid Build Coastguard Worker 58*523fa7a6SAndroid Build Coastguard Worker Tensor& quantized_embedding_xbit_dtype_out( 59*523fa7a6SAndroid Build Coastguard Worker KernelRuntimeContext& context, 60*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight, 61*523fa7a6SAndroid Build Coastguard Worker const Tensor& weight_scales, 62*523fa7a6SAndroid Build Coastguard Worker const exec_aten::optional<Tensor>& opt_weight_zero_points, 63*523fa7a6SAndroid Build Coastguard Worker int64_t weight_quant_min, 64*523fa7a6SAndroid Build Coastguard Worker int64_t weight_quant_max, 65*523fa7a6SAndroid Build Coastguard Worker const Tensor& indices, 66*523fa7a6SAndroid Build Coastguard Worker exec_aten::optional<ScalarType> out_dtype, 67*523fa7a6SAndroid Build Coastguard Worker Tensor& out, 68*523fa7a6SAndroid Build Coastguard Worker int weight_nbit); 69*523fa7a6SAndroid Build Coastguard Worker 70*523fa7a6SAndroid Build Coastguard Worker } // namespace native 71*523fa7a6SAndroid Build Coastguard Worker } // namespace executor 72*523fa7a6SAndroid Build Coastguard Worker } // namespace torch 73