1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
17
18 #include <algorithm>
19 #include <limits>
20
21 #include "fixedpoint/fixedpoint.h"
22 #include "tensorflow/lite/kernels/internal/common.h"
23 #include "tensorflow/lite/kernels/internal/cppmath.h"
24 #include "tensorflow/lite/kernels/internal/quantization_util.h"
25 #include "tensorflow/lite/kernels/internal/types.h"
26 #include "tensorflow/lite/kernels/op_macros.h"
27
28 namespace tflite {
29 namespace reference_ops {
30
Softmax(const SoftmaxParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)31 inline void Softmax(const SoftmaxParams& params,
32 const RuntimeShape& input_shape, const float* input_data,
33 const RuntimeShape& output_shape, float* output_data) {
34 const int trailing_dim = input_shape.DimensionsCount() - 1;
35 const int outer_size =
36 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
37 const int depth =
38 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
39
40 for (int i = 0; i < outer_size; ++i) {
41 // Find max element value which we'll use to ensure numerical stability
42 // taking advantage of the following equality:
43 // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
44 float max = std::numeric_limits<float>::lowest();
45 for (int c = 0; c < depth; ++c) {
46 max = std::max(max, input_data[i * depth + c]);
47 }
48
49 // Compute sum.
50 float sum = 0.f;
51 for (int c = 0; c < depth; ++c) {
52 const float exp_c = std::exp((input_data[i * depth + c] - max) *
53 static_cast<float>(params.beta));
54 output_data[i * depth + c] = exp_c;
55 sum += exp_c;
56 }
57
58 // Compute result.
59 for (int c = 0; c < depth; ++c) {
60 output_data[i * depth + c] = output_data[i * depth + c] / sum;
61 }
62 }
63 }
64
// Quantized softmax with int8_t/uint8_t input and int8_t/uint8_t/int16_t
// output.
//
// Works entirely in 32-bit fixed-point via gemmlowp's FixedPoint types:
// per row, exp(scaled(input - max)) is accumulated, the reciprocal of the
// sum is taken, and each exp value is rescaled into the full range of
// OutputT. params.input_multiplier / params.input_left_shift encode the
// beta * input_scale factor; params.diff_min is the threshold below which
// an element's contribution is treated as negligible.
template <typename InputT, typename OutputT>
inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const InputT* input_data,
                    const RuntimeShape& output_shape, OutputT* output_data) {
  const int32_t input_beta_multiplier = params.input_multiplier;
  const int32_t input_beta_left_shift = params.input_left_shift;
  const int diff_min = params.diff_min;
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large as
  // -32 before multiplying by input_beta_multiplier, and therefore as large as
  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
  // accumulation, but exp(-16) definitely is.
  static const int kScaledDiffIntegerBits = 5;
  static const int kAccumulationIntegerBits = 12;
  using FixedPointScaledDiff =
      gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
  using FixedPointAccum =
      gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;

  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    // Row maximum: softmax is shift-invariant, so subtracting it keeps all
    // exp() arguments non-positive.
    InputT max_in_row = std::numeric_limits<InputT>::min();
    for (int c = 0; c < depth; ++c) {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

    // First pass: accumulate sum of exp(input - max) in Q12 fixed point.
    // Elements with (input - max) < diff_min contribute negligibly and are
    // skipped here (and mapped to the minimum output value below).
    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    for (int c = 0; c < depth; ++c) {
      int32_t input_diff =
          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        const int32_t input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                        exp_on_negative_values(scaled_diff_f8));
      }
    }

    // Reciprocal of the accumulated sum as a Q0 (Q0.31) fixed-point value;
    // num_bits_over_unit records how many bits the true reciprocal exceeds
    // that representation by, compensated for in the final shift below.
    int num_bits_over_unit;
    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));

    // Second pass: recompute each exp(), multiply by the reciprocal, and
    // rounding-shift the product down into the bit-width of OutputT. The
    // result is offset so OutputT's minimum corresponds to probability 0,
    // then saturated to OutputT's range.
    for (int c = 0; c < depth; ++c) {
      int32_t input_diff =
          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        const int32_t input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);

        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
        int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
            (shifted_scale * exp_in_0).raw(),
            num_bits_over_unit + 31 - (sizeof(OutputT) * 8));

        const int32_t shifted_output =
            unsat_output +
            static_cast<int32_t>(std::numeric_limits<OutputT>::min());

        output_data[i * depth + c] = static_cast<OutputT>(std::max(
            std::min(shifted_output,
                     static_cast<int32_t>(std::numeric_limits<OutputT>::max())),
            static_cast<int32_t>(std::numeric_limits<OutputT>::min())));
      } else {
        // Negligible probability: emit the quantized representation of 0.
        output_data[i * depth + c] = std::numeric_limits<OutputT>::min();
      }
    }
  }
}
147
148 // Computes exp(input - max_input)
SoftMaxCalculateExp(const SoftmaxParams & params,const int16_t * input_data,const int depth,int16_t max_in_row,int i,int c)149 inline int16_t SoftMaxCalculateExp(const SoftmaxParams& params,
150 const int16_t* input_data, const int depth,
151 int16_t max_in_row, int i, int c) {
152 int32_t input_diff = input_data[i * depth + c] - max_in_row;
153 // scale the input_diff such that [-65535, 0] correspond to [-10.0, 0.0]
154 // exp lut generated with range [-10, 0], as exp(-10) is negligible.
155 int32_t scaled_diff = MultiplyByQuantizedMultiplier(
156 input_diff, params.input_multiplier, params.input_left_shift);
157 // recenter to [-32768, 32767]
158 int32_t sym_scaled_diff = scaled_diff + 32767;
159 int16_t sat_sym_scaled_diff =
160 std::min(std::max(sym_scaled_diff, static_cast<int32_t>(-32768)),
161 static_cast<int32_t>(32767));
162 // apply the exp() LUT activation function
163 return lut_lookup(sat_sym_scaled_diff, params.exp_lut);
164 }
// Quantized softmax with int16_t input and int16_t output.
//
// Per row: computes Q0.15 exp values through the LUT in params.exp_lut,
// sums them, takes 1/sum through the params.one_over_one_plus_x_lut LUT,
// and rescales each exp value by that reciprocal. Output values in
// [0, 32767] correspond to probabilities [0.0, 1.0].
inline void SoftmaxInt16(const SoftmaxParams& params,
                         const RuntimeShape& input_shape,
                         const int16_t* input_data,
                         const RuntimeShape& output_shape,
                         int16_t* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    // Find the largest element (subtracted inside SoftMaxCalculateExp for
    // numerical stability).
    int16_t max_in_row = std::numeric_limits<int16_t>::min();
    for (int c = 0; c < depth; ++c) {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

    // This loops computes the exp values and their sum. We will need the exp
    // values later on in the function so we cache them in the output_data
    // buffer. This is an optimization done to avoid calculating the exp values
    // twice making use of the output_data buffer as scratch memory.
    int32_t sum_of_exps = 0;  // Q16.15 fixed point format.
    int16_t* exp_results_Q015 = output_data + i * depth;
    for (int c = 0; c < depth; ++c) {
      exp_results_Q015[c] =
          SoftMaxCalculateExp(params, input_data, depth, max_in_row, i, c);
      sum_of_exps += exp_results_Q015[c];
    }

    // Compute the reciprocal 1/sum_of_exps.
    // headroom_plus_one is the leading-zero count of the (positive) sum;
    // at least one exp value is near 1.0 so the sum is nonzero.
    uint8_t headroom_plus_one =
        CountLeadingZeros(static_cast<uint32_t>(sum_of_exps));
    // Left-justify the sum (leading bit to position 30), then round-shift
    // down by 14 so shifted_sum lies in [1 << 16, 1 << 17) — i.e. the sum
    // normalized to [1.0, 2.0) with 16 fractional bits.
    int32_t shifted_sum =
        ((static_cast<int64_t>(sum_of_exps) << (headroom_plus_one - 1)) +
         (1 << 13)) >>
        14;
    // since the LUT computes 1/(1 + x) we need to first compute x = (sum - 1).
    // also, the LUT expects a symmetrical input, so we must also recenter x
    // from [0, 65535] to [-32768, 32767].
    int32_t sym_shifted_sum = shifted_sum + (-((1 << 15) + (1 << 16)));
    int16_t sat_sym_shifted_sum = static_cast<int16_t>(
        std::min(std::max(sym_shifted_sum, static_cast<int32_t>(-32768)),
                 static_cast<int32_t>(32767)));
    // apply 1/(1 + x) LUT activation function
    int16_t reciprocal_scale_Q015 =
        lut_lookup(sat_sym_shifted_sum, params.one_over_one_plus_x_lut);

    // Rescale the exp_result with reciprocal
    // range of output is [0, 32767] correspond to [0.0, 1.0]
    for (int c = 0; c < depth; ++c) {
      // right_shift undoes the normalization applied when computing the
      // reciprocal (31 - headroom), with round-to-nearest via `round`.
      uint8_t right_shift = 31 - headroom_plus_one;
      int64_t round = 1 << (right_shift - 1);
      int32_t result = (static_cast<int64_t>(exp_results_Q015[c]) *
                            static_cast<int64_t>(reciprocal_scale_Q015) +
                        round) >>
                       right_shift;
      // Saturate to the valid probability range [0, 32767].
      output_data[i * depth + c] = static_cast<int16_t>(
          std::min(std::max(result, static_cast<int32_t>(0)),
                   static_cast<int32_t>(32767)));
    }
  }
}
229
230 } // namespace reference_ops
231 } // namespace tflite
232
233 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
234