1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/cadence/reference/kernels/kernels.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11
12 namespace impl {
13 namespace reference {
14 namespace native {
15
16 using executorch::aten::Tensor;
17 using executorch::runtime::KernelRuntimeContext;
18
19 template <typename T>
quantized_relu_(const Tensor & input,const Tensor & in_zero_point,const int64_t out_zero_point,const Tensor & out_multiplier,const Tensor & out_shift,Tensor & output)20 void quantized_relu_(
21 const Tensor& input,
22 const Tensor& in_zero_point,
23 const int64_t out_zero_point,
24 const Tensor& out_multiplier,
25 const Tensor& out_shift,
26 Tensor& output) {
27 T q_zero_point = in_zero_point.const_data_ptr<T>()[0];
28 const T* __restrict__ in = input.const_data_ptr<T>();
29 T* __restrict__ out = output.mutable_data_ptr<T>();
30
31 const int32_t* __restrict__ out_multiplier_data =
32 out_multiplier.const_data_ptr<int32_t>();
33 const int32_t* __restrict__ out_shift_data =
34 out_shift.const_data_ptr<int32_t>();
35
36 // Compute the out_scale from out_multiplier and out_shift
37 const float out_scale =
38 -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]);
39
40 for (size_t i = 0, e = input.numel(); i < e; ++i) {
41 const T temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0;
42 out[i] = kernels::quantize<T>(temp, out_scale, out_zero_point);
43 }
44 }
45
quantized_relu_out(KernelRuntimeContext & ctx,const Tensor & input,const Tensor & in_zero_point,const int64_t out_zero_point,const Tensor & out_multiplier,const Tensor & out_shift,Tensor & output)46 void quantized_relu_out(
47 KernelRuntimeContext& ctx,
48 const Tensor& input,
49 const Tensor& in_zero_point,
50 const int64_t out_zero_point,
51 const Tensor& out_multiplier,
52 const Tensor& out_shift,
53 Tensor& output) {
54 if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
55 quantized_relu_<uint8_t>(
56 input,
57 in_zero_point,
58 out_zero_point,
59 out_multiplier,
60 out_shift,
61 output);
62 } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
63 quantized_relu_<int8_t>(
64 input,
65 in_zero_point,
66 out_zero_point,
67 out_multiplier,
68 out_shift,
69 output);
70 } else {
71 ET_CHECK_MSG(
72 false,
73 "Unhandled input dtype %hhd",
74 static_cast<int8_t>(input.scalar_type()));
75 }
76 }
77
78 }; // namespace native
79 }; // namespace reference
80 }; // namespace impl
81