/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/quantized/cpu/embeddingxb.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <algorithm>
#include <cinttypes>
#include <cmath>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight is quantized per channel, with a scale and
 * zero_point for each embedding.
 *
 * This is the out variant of torch.ops.quantized.embedding_4bit.
 *
 * NOTE: quant_min, quant_max, and Dtype are not used in computation; they are
 * metadata that is passed around and can be useful for pattern matching. See
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for
 * more info.
 */
Tensor& quantized_embedding_4bit_out(
    // TODO Evaluate whether this name is appropriate for an operator that
    // takes non-quantized input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  return quantized_embedding_xbit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out,
      4);
}

Tensor& quantized_embedding_4bit_out(
    KernelRuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  return quantized_embedding_xbit_out(
      context,
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out,
      4);
}

Tensor& quantized_embedding_4bit_dtype_out(
    // TODO Evaluate whether this name is appropriate for an operator that
    // takes non-quantized input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  return quantized_embedding_xbit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      4);
}

Tensor& quantized_embedding_4bit_dtype_out(
    KernelRuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  return quantized_embedding_xbit_dtype_out(
      context,
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      4);
}

} // namespace native
} // namespace executor
} // namespace torch
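
// ----------------------------------------------------------------------------
// Illustrative usage sketch (not part of the kernel). Tensor construction is
// omitted because it depends on the caller's environment. The shapes, the
// [-8, 7] quant range, and the packing of two 4-bit codes per uint8 byte are
// stated here as assumptions typical of a signed 4-bit scheme, not guarantees
// made by this file. Conceptually, each looked-up code is dequantized per
// channel as
//
//   out[i][j] = (code - zero_point[row]) * scale[row]
//
// where zero_point is conventionally taken as 0 when opt_weight_zero_points
// is absent.
//
//   // weight:        (num_embeddings, embedding_dim / 2) uint8; each byte
//   //                packs two 4-bit quantized codes.
//   // weight_scales: (num_embeddings,) per-channel scales.
//   // indices:       int64 tensor of embedding rows to fetch.
//   // out:           (indices.numel(), embedding_dim) float tensor, already
//   //                sized by the caller, as is usual for out-variant ops.
//   quantized_embedding_4bit_out(
//       weight,
//       weight_scales,
//       opt_weight_zero_points, // exec_aten::optional<Tensor>; may be empty
//       -8, // weight_quant_min: metadata only, see NOTE above
//       7,  // weight_quant_max: metadata only
//       indices,
//       out);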