/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <unordered_set>
#include <vector>

#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"

namespace seq_flow_lite {
namespace ops {
namespace custom {

namespace {

const int kInputIndex = 0;
const int kScaleIndex = 1;
const int kOffsetIndex = 2;
const int kAxisIndex = 3;
const int kOutputIndex = 0;

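// Validates the input tensor types and shapes, and resizes the output tensor
// to match the input dimensions.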
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
  if (node->outputs->size != 1) {
    return kTfLiteError;
  }

  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputIndex]];
  TfLiteTensor* scale = &context->tensors[node->inputs->data[kScaleIndex]];
  TfLiteTensor* offset = &context->tensors[node->inputs->data[kOffsetIndex]];
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteUInt8);
  TF_LITE_ENSURE_EQ(context, offset->dims->data[0], 1);
  TF_LITE_ENSURE_EQ(context, offset->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, offset->type, kTfLiteUInt8);
  TF_LITE_ENSURE_EQ(context, scale->dims->data[0], 1);
  TF_LITE_ENSURE_EQ(context, scale->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, scale->type, kTfLiteUInt8);
  if (node->inputs->size == 4) {
    TfLiteTensor* axis = &context->tensors[node->inputs->data[kAxisIndex]];
    TF_LITE_ENSURE_EQ(context, axis->type, kTfLiteInt32);
  }

  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputIndex]];
  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteUInt8);
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

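// Returns the total number of elements in the input tensor.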
int GetNumberOfSteps(const TfLiteTensor* input) {
  int number_of_steps = 1;
  for (int i = 0; i < input->dims->size; ++i) {
    number_of_steps *= input->dims->data[i];
  }
  return number_of_steps;
}

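// Returns the number of elements along the given reduction axes.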
inline int GetNumberOfFeatures(const TfLiteTensor* input, const int* axis,
                               const int num_axis) {
  int num_features = 1;
  for (int i = 0; i < num_axis; ++i) {
    num_features *= input->dims->data[axis[i]];
  }
  return num_features;
}

// Performs sanity checks on input axis and resolves into valid dimensions.
inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis,
                        int* out_axis, int* out_num_axis) {
  *out_num_axis = 0;
  // Short-circuit axis resolution for scalars; the axis will go unused.
  if (num_dims == 0) {
    return true;
  }

  // Using an unordered set to reduce complexity in looking up duplicates.
  std::unordered_set<int> unique_indices;
  for (int64_t idx = 0; idx < num_axis; ++idx) {
    // Handle negative index.
    int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
    assert(current >= 0 && current < num_dims);
    // Only adding the axis if it wasn't added before.
    if (unique_indices.find(current) == unique_indices.end()) {
      unique_indices.insert(current);
      out_axis[*out_num_axis] = current;
      *out_num_axis += 1;
    }
  }
  return true;
}

// Given the current position in the input array, this helper computes the
// next valid index, returning false once every position has been visited.
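// E.g. for input_dims {2, 3} and curr_pos starting at {0, 0}, successive
// calls advance curr_pos to {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2} (each
// returning true) and then wrap back to {0, 0} and return false.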
bool ValidIndex(const int* input_dims, const int input_dims_size,
                int* curr_pos) {
  if (input_dims_size == 0) {
    return false;
  }
  assert(input_dims != nullptr);
  assert(curr_pos != nullptr);
  for (int idx = input_dims_size - 1; idx >= 0; --idx) {
    int current_val = curr_pos[idx] + 1;
    assert(input_dims[idx] >= current_val);
    if (input_dims[idx] == current_val) {
      curr_pos[idx] = 0;
    } else {
      curr_pos[idx] = current_val;
      return true;
    }
  }
  return false;
}

// Gets next offset depending on reduction axis. Implementation borrowed from
// tflite reduce mean implementation.
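// E.g. for input_dims {2, 3, 4} and curr_pos {1, 2, 3}: with axis == nullptr
// the result is the row-major flat index 1 * 12 + 2 * 4 + 3 = 23, while with
// axis {1} dimension 1 is skipped and the result is 1 * 4 + 3 = 7.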
int GetOffset(const int* input_dims, const int input_dims_size,
              const int* curr_pos, const int* axis, const int axis_size) {
  if (input_dims_size == 0) return 0;
  assert(input_dims != nullptr);
  assert(curr_pos != nullptr);
  int offset = 0;
  for (int idx = 0; idx < input_dims_size; ++idx) {
    // If idx is part of the reduction axes, we skip the offset calculation.
    bool is_axis = false;
    if (axis != nullptr) {
      for (int redux = 0; redux < axis_size; ++redux) {
        if (idx == axis[redux]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) offset = offset * input_dims[idx] + curr_pos[idx];
  }

  return offset;
}

// TODO(b/132896827): Current implementation needs further evaluation to reduce
// space and time complexity.
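// Layer normalization over an arbitrary set of reduction axes. A first pass
// accumulates per-group sums and squared sums, and a second pass applies the
// resulting per-group multiplier and bias to every element.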
TfLiteStatus FlexibleLayerNorm(const TfLiteTensor* input, const float scale,
                               const float offset, const int* axis,
                               const int num_axis, TfLiteTensor* output) {
  int num_features = GetNumberOfFeatures(input, &axis[0], num_axis);
  int time_steps = static_cast<int>(GetNumberOfSteps(input) / num_features);

  std::vector<float> sum_x(time_steps, 0.0f);
  std::vector<float> sum_xx(time_steps, 0.0f);
  std::vector<int> index_iter(input->dims->size, 0);

  // Computing sum and squared sum for features across the reduction axes.
  do {
    // Not passing reduction axes to get the input offset as we are simply
    // iterating through the multidimensional array.
    int input_offset = GetOffset(input->dims->data, input->dims->size,
                                 &index_iter[0], nullptr, 0);
    // Passing in the valid reduction axes as we would like to get the output
    // offset after reduction.
    int stats_offset = GetOffset(input->dims->data, input->dims->size,
                                 &index_iter[0], &axis[0], num_axis);
    float input_val = PodDequantize(*input, input_offset);
    sum_x[stats_offset] += input_val;
    sum_xx[stats_offset] += input_val * input_val;
  } while (ValidIndex(input->dims->data, input->dims->size, &index_iter[0]));

  std::vector<float> multiplier(time_steps, 1.0f);
  std::vector<float> bias(time_steps, 0.0f);

  // Computing stats for the reduction axes.
  for (int i = 0; i < time_steps; ++i) {
    sum_x[i] = sum_x[i] / num_features;
    sum_xx[i] = sum_xx[i] / num_features;
    const float variance = sum_xx[i] - sum_x[i] * sum_x[i];
    const float inverse_stddev = 1 / sqrt(variance + 1e-6);
    multiplier[i] = inverse_stddev * scale;
    bias[i] = offset - sum_x[i] * inverse_stddev * scale;
  }

  const float out_inverse_scale = 1.0f / output->params.scale;
  const int32_t out_zero_point = output->params.zero_point;
  uint8_t* out_ptr = output->data.uint8;
  std::fill(index_iter.begin(), index_iter.end(), 0);

  // Using the stats to fill the output pointer.
  do {
    // Not passing reduction axes to get the input offset as we are simply
    // iterating through the multidimensional array.
    int input_offset = GetOffset(input->dims->data, input->dims->size,
                                 &index_iter[0], nullptr, 0);
    // Passing in the valid reduction axes as we would like to get the output
    // offset after reduction.
    int stats_offset = GetOffset(input->dims->data, input->dims->size,
                                 &index_iter[0], &axis[0], num_axis);
    float input_val = PodDequantize(*input, input_offset);

    const float value =
        input_val * multiplier[stats_offset] + bias[stats_offset];
    out_ptr[input_offset] =
        PodQuantize(value, out_zero_point, out_inverse_scale);
  } while (ValidIndex(input->dims->data, input->dims->size, &index_iter[0]));

  return kTfLiteOk;
}

/*
 * Layer normalization is optimized as follows in integer arithmetic.
 *
 * Algorithm
 * *********
 * Subscript i \in {1, ..., N}, inputs q_i, outputs oq_i.
 *
 * x_i = (q_i - input_zero_point) * input_scale
 * mean = sum_i x_i / N
 * var = sum_i (x_i * x_i / N) - mean * mean
 * std = sqrt(var + tolerance)
 * xn_i = (x_i - mean) / std
 * y_i = xn_i * scale + offset
 * o_i = round(y_i / output_scale + output_zero_point)
 * oq_i = clamp(o_i, 0, 255)
 *
 * Optimizations
 * *************
 * Applying linear expansion
 * x_i = q_i * input_scale - input_zero_point * input_scale
 * or x_i = m * q_i + c
 * mean = m * mean_q + c
 * Variance is not affected by a constant shift of the input
 * var = m^2 * var_q
 * std = m * sqrt(var_q + tolerance)
 * Expanding x_i, mean and std in the equation for xn_i
 * xn_i = (m * q_i + c - m * mean_q - c) / (m * sqrt(var_q + tolerance))
 * Simplifying
 * xn_i = (q_i - mean_q) / sqrt(var_q + tolerance)
 * Setting inv_std_q = 1 / sqrt(var_q + tolerance)
 * xn_i = q_i * inv_std_q - mean_q * inv_std_q
 * y_i = q_i * inv_std_q * scale - mean_q * inv_std_q * scale + offset
 * o_i = round(q_i * inv_std_q * scale / output_scale
 *             - mean_q * inv_std_q * scale / output_scale
 *             + offset / output_scale
 *             + output_zero_point)
 * Setting
 * static_bias = offset / output_scale + output_zero_point
 * static_scale = scale / output_scale
 * o_i = round(q_i * inv_std_q * static_scale
 *             - mean_q * inv_std_q * static_scale
 *             + static_bias)
 * Setting
 * dynamic_scale = inv_std_q * static_scale
 * dynamic_bias = static_bias - mean_q * dynamic_scale
 * o_i = round(q_i * dynamic_scale + dynamic_bias)
 * oq_i = clamp(round(q_i * dynamic_scale + dynamic_bias), 0, 255)
 *
 * This results in the optimized implementation below. The strategy is to first
 * compute the first and second order summary statistics of q_i in a loop, then
 * compute mean_q, var_q and finally dynamic_scale/dynamic_bias. This allows
 * oq_i to be computed quickly in a tight loop.
 */
TfLiteStatus IntegerLayerNorm(const TfLiteTensor* input, const float scale,
                              const float offset, TfLiteTensor* output) {
  const int input_rank = input->dims->size;
  const int num_features = input->dims->data[input_rank - 1];
  const int time_steps =
      static_cast<int>(GetNumberOfSteps(input) / num_features);

  const float out_inverse_scale = 1.0f / output->params.scale;
  const float static_scale = scale * out_inverse_scale;
  const float static_bias = static_cast<float>(output->params.zero_point) +
                            offset * out_inverse_scale;
  const float inverse_num_features = 1.0f / num_features;
  const uint8_t* const in_ptr = input->data.uint8;
  uint8_t* out_ptr = output->data.uint8;
  for (int i = 0; i < time_steps; ++i) {
    int32_t i32_sum_q = 0;
    int32_t i32_sum_qq = 0;
    const int32_t index = i * num_features;
    for (int j = index; j < index + num_features; ++j) {
      const int32_t q_i = static_cast<int32_t>(in_ptr[j]);
      // Compute first and second order statistics for q_i.
      i32_sum_q += q_i;
      i32_sum_qq += q_i * q_i;
    }
    const float second_moment_qq = i32_sum_qq * inverse_num_features;
    const float mean_q = i32_sum_q * inverse_num_features;
    const float var_q = second_moment_qq - mean_q * mean_q;
    const float inv_std_q = 1.0f / sqrt(var_q + 1e-6);
    const float dynamic_scale = inv_std_q * static_scale;
    const float dynamic_bias = static_bias - mean_q * dynamic_scale;
    for (int j = index; j < index + num_features; ++j) {
      const int32_t invalue = static_cast<int32_t>(in_ptr[j]);
      const float value = invalue * dynamic_scale + dynamic_bias;
      // Add an offset of +/-0.5 before the cast to perform float rounding.
      const int32_t i32value =
          static_cast<int32_t>(value + ((value >= 0.0) ? 0.5f : -0.5f));
      // Clamp the result to the uint8 range.
      out_ptr[j] = static_cast<uint8_t>(std::max(std::min(255, i32value), 0));
    }
  }
  return kTfLiteOk;
}

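// Layer normalization over the last axis for float32 inputs and outputs.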
TfLiteStatus DefaultLayerNormFloat(const TfLiteTensor* input, const float scale,
                                   const float offset, TfLiteTensor* output) {
  const int input_rank = input->dims->size;
  const int num_features = input->dims->data[input_rank - 1];
  const int time_steps =
      static_cast<int>(GetNumberOfSteps(input) / num_features);
  float* out_ptr = output->data.f;
  for (int i = 0; i < time_steps; ++i) {
    float sum_x = 0;
    float sum_xx = 0;
    for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
      sum_x += input->data.f[index];
      sum_xx += input->data.f[index] * input->data.f[index];
    }
    const float exp_xx = sum_xx / num_features;
    const float exp_x = sum_x / num_features;
    const float variance = exp_xx - exp_x * exp_x;
    const float inverse_stddev = 1 / sqrt(variance + 1e-6);
    const float multiplier = inverse_stddev * scale;
    const float bias = offset - exp_x * inverse_stddev * scale;
    for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
      out_ptr[index] = input->data.f[index] * multiplier + bias;
    }
  }
  return kTfLiteOk;
}

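// Layer normalization over the last axis for uint8 inputs; dequantizes each
// time step into a temporary float buffer, normalizes it and requantizes the
// result.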
TfLiteStatus DefaultLayerNorm(const TfLiteTensor* input, const float scale,
                              const float offset, TfLiteTensor* output) {
  const int input_rank = input->dims->size;
  const int num_features = input->dims->data[input_rank - 1];
  const int time_steps =
      static_cast<int>(GetNumberOfSteps(input) / num_features);

  std::vector<float> temp_buffer(num_features, 0.0f);
  const float out_inverse_scale = 1.0f / output->params.scale;
  const int32_t out_zero_point = output->params.zero_point;
  uint8_t* out_ptr = output->data.uint8;
  for (int i = 0; i < time_steps; ++i) {
    float sum_x = 0;
    float sum_xx = 0;
    for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
      temp_buffer[j] = PodDequantize(*input, index);
      sum_x += temp_buffer[j];
      sum_xx += temp_buffer[j] * temp_buffer[j];
    }
    const float exp_xx = sum_xx / num_features;
    const float exp_x = sum_x / num_features;
    const float variance = exp_xx - exp_x * exp_x;
    const float inverse_stddev = 1 / sqrt(variance + 1e-6);
    const float multiplier = inverse_stddev * scale;
    const float bias = offset - exp_x * inverse_stddev * scale;
    for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
      const float value = temp_buffer[j] * multiplier + bias;
      out_ptr[index] = PodQuantize(value, out_zero_point, out_inverse_scale);
    }
  }
  return kTfLiteOk;
}

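// Computes layer normalization for the node. Dispatches to the specialized
// last-axis implementations when the axis input reduces only the last
// dimension, and falls back to FlexibleLayerNorm for arbitrary axes.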
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input =
      &context->tensors[node->inputs->data[kInputIndex]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputIndex]];
  TfLiteTensor scale_tensor = context->tensors[node->inputs->data[kScaleIndex]];
  TfLiteTensor offset_tensor =
      context->tensors[node->inputs->data[kOffsetIndex]];
  float scale = 1.0;
  float offset = 0.0;
  if (input->type == kTfLiteUInt8) {
    scale = PodDequantize(scale_tensor, 0);
    offset = PodDequantize(offset_tensor, 0);
  } else {
    scale = scale_tensor.data.f[0];
    offset = offset_tensor.data.f[0];
  }

  TfLiteTensor* axis = &context->tensors[node->inputs->data[kAxisIndex]];
  int num_axis = static_cast<int>(tflite::NumElements(axis));
  // For backward compatibility, layer norm over the last axis is handled by
  // the specialized implementations below.
  if (num_axis == 1 && (axis->data.i32[0] == -1 ||
                        axis->data.i32[0] == (input->dims->size - 1))) {
    if (input->type == kTfLiteUInt8) {
      return IntegerLayerNorm(input, scale, offset, output);
    } else if (input->type == kTfLiteFloat32) {
      return DefaultLayerNormFloat(input, scale, offset, output);
    } else {
      TF_LITE_ENSURE_MSG(context, false,
                         "Input should be either UInt8 or Float32.");
    }
  }

  std::vector<int> resolved_axis(num_axis);
  // Resolve the reduction axes.
  int num_resolved_axis = 0;
  if (!ResolveAxis(input->dims->size, axis->data.i32, num_axis,
                   &resolved_axis[0], &num_resolved_axis)) {
    return kTfLiteError;
  }

  return FlexibleLayerNorm(input, scale, offset, &resolved_axis[0],
                           num_resolved_axis, output);
}

}  // namespace

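// Returns the registration for the custom layer norm op; Resize serves as the
// prepare function and Eval as the invoke function. Typical client usage
// (illustrative only; the actual op name must match the one emitted by the
// model converter) would be:
//   resolver.AddCustom("LayerNorm", Register_LAYER_NORM());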
TfLiteRegistration* Register_LAYER_NORM() {
  static TfLiteRegistration r = {nullptr, nullptr, Resize, Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace seq_flow_lite