/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/reference/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace impl {
namespace reference {
namespace native {

using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;

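// Quantized ReLU over one tensor: each element is clamped at the input zero
// point (the quantized representation of real 0.0), and the clamped value is
// requantized with the scale derived from out_multiplier/out_shift and the
// given output zero point.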
template <typename T>
void quantized_relu_(
    const Tensor& input,
    const Tensor& in_zero_point,
    const int64_t out_zero_point,
    const Tensor& out_multiplier,
    const Tensor& out_shift,
    Tensor& output) {
  T q_zero_point = in_zero_point.const_data_ptr<T>()[0];
  const T* __restrict__ in = input.const_data_ptr<T>();
  T* __restrict__ out = output.mutable_data_ptr<T>();

  const int32_t* __restrict__ out_multiplier_data =
      out_multiplier.const_data_ptr<int32_t>();
  const int32_t* __restrict__ out_shift_data =
      out_shift.const_data_ptr<int32_t>();

  // Compute the out_scale from out_multiplier and out_shift
  const float out_scale =
      -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]);

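  // ReLU in the quantized domain: values at or below the input zero point map
  // to 0; values above it are offset by the zero point and then requantized
  // into the output's quantization parameters.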
  for (size_t i = 0, e = input.numel(); i < e; ++i) {
    const T temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0;
    out[i] = kernels::quantize<T>(temp, out_scale, out_zero_point);
  }
}

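// Out-variant operator entry point: dispatches quantized_relu_ on the input
// dtype (uint8_t for Byte, int8_t for Char) and reports an error for any
// other dtype.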
void quantized_relu_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const Tensor& in_zero_point,
    const int64_t out_zero_point,
    const Tensor& out_multiplier,
    const Tensor& out_shift,
    Tensor& output) {
  if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
    quantized_relu_<uint8_t>(
        input,
        in_zero_point,
        out_zero_point,
        out_multiplier,
        out_shift,
        output);
  } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
    quantized_relu_<int8_t>(
        input,
        in_zero_point,
        out_zero_point,
        out_multiplier,
        out_shift,
        output);
  } else {
    ET_CHECK_MSG(
        false,
        "Unhandled input dtype %hhd",
        static_cast<int8_t>(input.scalar_type()));
  }
}

}; // namespace native
}; // namespace reference
}; // namespace impl