/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace reference_ops {

inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const float* input_data,
                    const RuntimeShape& output_shape, float* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    // Find the max element value, which we'll use to ensure numerical
    // stability by taking advantage of the following equality:
    // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
    float max = std::numeric_limits<float>::lowest();
    for (int c = 0; c < depth; ++c) {
      max = std::max(max, input_data[i * depth + c]);
    }

    // Compute sum.
    float sum = 0.f;
    for (int c = 0; c < depth; ++c) {
      const float exp_c = std::exp((input_data[i * depth + c] - max) *
                                   static_cast<float>(params.beta));
      output_data[i * depth + c] = exp_c;
      sum += exp_c;
    }

    // Compute result.
    for (int c = 0; c < depth; ++c) {
      output_data[i * depth + c] = output_data[i * depth + c] / sum;
    }
  }
}
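
// A minimal usage sketch for the float path above (shapes and values are
// purely illustrative, not taken from any particular model):
//
//   SoftmaxParams params;
//   params.beta = 1.0;
//   const RuntimeShape shape({2, 3});  // Softmax runs over the last dim.
//   const float input[6] = {1.f, 2.f, 3.f, 1.f, 1.f, 1.f};
//   float output[6];
//   Softmax(params, shape, input, shape, output);
//   // Each row of output now sums to (approximately) 1.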

// Quantized softmax with int8_t/uint8_t input and int8_t/uint8_t/int16_t
// output.
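// For the quantized path below, params.input_multiplier,
// params.input_left_shift and params.diff_min are assumed to be precomputed
// from beta and the input scale; the builtin TFLite kernels typically derive
// them with PreprocessSoftmaxScaling() and CalculateInputRadius() from
// quantization_util.h (with diff_min set to -CalculateInputRadius(...)).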
template <typename InputT, typename OutputT>
inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const InputT* input_data,
                    const RuntimeShape& output_shape, OutputT* output_data) {
  const int32_t input_beta_multiplier = params.input_multiplier;
  const int32_t input_beta_left_shift = params.input_left_shift;
  const int diff_min = params.diff_min;
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large as
  // -32 before multiplying by input_beta_multiplier, and therefore as large as
  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
  // accumulation, but exp(-16) definitely is.
  static const int kScaledDiffIntegerBits = 5;
  static const int kAccumulationIntegerBits = 12;
  using FixedPointScaledDiff =
      gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
  using FixedPointAccum =
      gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
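  // FixedPoint0 is a Q0.31 value (range roughly [-1, 1)); it is used below
  // for exp() of non-positive inputs and for the normalized reciprocal of
  // the sum, both of which fit in that range.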

  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    InputT max_in_row = std::numeric_limits<InputT>::min();
    for (int c = 0; c < depth; ++c) {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

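    // Entries whose difference from the row maximum is below diff_min
    // contribute a negligible exp() term; they are skipped here and written
    // out as the minimum OutputT value in the second pass below.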
    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    for (int c = 0; c < depth; ++c) {
      int32_t input_diff =
          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        const int32_t input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                        exp_on_negative_values(scaled_diff_f8));
      }
    }

    int num_bits_over_unit;
    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));

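    // shifted_scale represents (1 / sum_of_exps) scaled up by
    // 2^num_bits_over_unit, stored as a Q0.31 value; the extra power-of-two
    // factor is divided out below, together with the narrowing to OutputT,
    // by the RoundingDivideByPOT shift of
    // num_bits_over_unit + 31 - 8 * sizeof(OutputT).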
    for (int c = 0; c < depth; ++c) {
      int32_t input_diff =
          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        const int32_t input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);

        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
        int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
            (shifted_scale * exp_in_0).raw(),
            num_bits_over_unit + 31 - (sizeof(OutputT) * 8));

        const int32_t shifted_output =
            unsat_output +
            static_cast<int32_t>(std::numeric_limits<OutputT>::min());

        output_data[i * depth + c] = static_cast<OutputT>(std::max(
            std::min(shifted_output,
                     static_cast<int32_t>(std::numeric_limits<OutputT>::max())),
            static_cast<int32_t>(std::numeric_limits<OutputT>::min())));
      } else {
        output_data[i * depth + c] = std::numeric_limits<OutputT>::min();
      }
    }
  }
}

// Computes exp(input - max_input)
inline int16_t SoftMaxCalculateExp(const SoftmaxParams& params,
                                   const int16_t* input_data, const int depth,
                                   int16_t max_in_row, int i, int c) {
  int32_t input_diff = input_data[i * depth + c] - max_in_row;
  // Scale input_diff such that [-65535, 0] corresponds to [-10.0, 0.0].
  // The exp LUT is generated for the range [-10, 0], as exp(-10) is
  // negligible.
  int32_t scaled_diff = MultiplyByQuantizedMultiplier(
      input_diff, params.input_multiplier, params.input_left_shift);
  // Recenter to [-32768, 32767].
  int32_t sym_scaled_diff = scaled_diff + 32767;
  int16_t sat_sym_scaled_diff =
      std::min(std::max(sym_scaled_diff, static_cast<int32_t>(-32768)),
               static_cast<int32_t>(32767));
  // Apply the exp() LUT activation function.
  return lut_lookup(sat_sym_scaled_diff, params.exp_lut);
}
// Quantized softmax with int16_t input and int16_t output.
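// params.exp_lut and params.one_over_one_plus_x_lut are expected to be
// int16_t lookup tables approximating exp(x) over [-10, 0] and 1/(1 + x)
// over [0, 1), indexed by a symmetric, recentered input as described in the
// comments below; they are assumed to be populated by the caller (e.g. at
// kernel Prepare() time).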
inline void SoftmaxInt16(const SoftmaxParams& params,
                         const RuntimeShape& input_shape,
                         const int16_t* input_data,
                         const RuntimeShape& output_shape,
                         int16_t* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    // Find the largest element
    int16_t max_in_row = std::numeric_limits<int16_t>::min();
    for (int c = 0; c < depth; ++c) {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

    // This loop computes the exp values and their sum. We will need the exp
    // values later on in the function, so we cache them in the output_data
    // buffer. This is an optimization done to avoid calculating the exp
    // values twice, making use of the output_data buffer as scratch memory.
    int32_t sum_of_exps = 0;  // Q16.15 fixed point format.
    int16_t* exp_results_Q015 = output_data + i * depth;
    for (int c = 0; c < depth; ++c) {
      exp_results_Q015[c] =
          SoftMaxCalculateExp(params, input_data, depth, max_in_row, i, c);
      sum_of_exps += exp_results_Q015[c];
    }

    // Compute the reciprocal 1/sum_of_exps
    uint8_t headroom_plus_one =
        CountLeadingZeros(static_cast<uint32_t>(sum_of_exps));
    int32_t shifted_sum =
        ((static_cast<int64_t>(sum_of_exps) << (headroom_plus_one - 1)) +
         (1 << 13)) >>
        14;
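    // shifted_sum rescales sum_of_exps by 2^(headroom_plus_one - 15), which
    // normalizes the sum into [65536, 131071], i.e. [1.0, 2.0) with 16
    // fractional bits; the same headroom factor is compensated for by
    // right_shift in the rescaling loop below.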
    // Since the LUT computes 1/(1 + x), we need to first compute
    // x = (sum - 1). Also, the LUT expects a symmetrical input, so we must
    // recenter x from [0, 65535] to [-32768, 32767].
    int32_t sym_shifted_sum = shifted_sum + (-((1 << 15) + (1 << 16)));
    int16_t sat_sym_shifted_sum = static_cast<int16_t>(
        std::min(std::max(sym_shifted_sum, static_cast<int32_t>(-32768)),
                 static_cast<int32_t>(32767)));
    // Apply the 1/(1 + x) LUT activation function.
    int16_t reciprocal_scale_Q015 =
        lut_lookup(sat_sym_shifted_sum, params.one_over_one_plus_x_lut);

    // Rescale the exp_results with the reciprocal. The output range
    // [0, 32767] corresponds to [0.0, 1.0].
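    // The reciprocal above was computed against the headroom-normalized sum,
    // so shifting the Q0.15 x Q0.15 product right by 31 - headroom_plus_one
    // both removes that normalization and drops the extra fractional bits,
    // leaving exp/sum in Q0.15.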
    for (int c = 0; c < depth; ++c) {
      uint8_t right_shift = 31 - headroom_plus_one;
      int64_t round = 1 << (right_shift - 1);
      int32_t result = (static_cast<int64_t>(exp_results_Q015[c]) *
                            static_cast<int64_t>(reciprocal_scale_Q015) +
                        round) >>
                       right_shift;
      output_data[i * depth + c] = static_cast<int16_t>(
          std::min(std::max(result, static_cast<int32_t>(0)),
                   static_cast<int32_t>(32767)));
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_