/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.h"
#include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace delegates {
namespace hexagon {
namespace {

constexpr uint8_t k8BitSignFlipConstant = 0x80;
// 1/1024 ~ 0.0009766 is a restriction set by Hexagon's kernels.
// TODO(b/151103818): Figure out a way to retrieve this constant reliably.
constexpr float kHexagonMinRelativeScale = 0.0009766f;
// Channel count to split depthwise convolution op.
// See Conv2dOpBuilder.should_split_dwconv_ for details.
constexpr int kDwConv5x5Filt2x2StrideChannelCount = 32;

}  // namespace

TfLiteStatus ProcessPerChannelQuantizedWeights(
    const TfLiteTensor& weights_tensor, TfLiteContext* context,
    float* weights_min, float* weights_max, GraphBuilder* graph_builder,
    PerChannelQuantData* per_channel_quant) {
  if (!per_channel_quant) return kTfLiteError;
  TfLiteAffineQuantization* weights_quant_params =
      reinterpret_cast<TfLiteAffineQuantization*>(
          weights_tensor.quantization.params);

  // Retrieve channel scales.
  per_channel_quant->num_scale_values = weights_quant_params->scale->size;
  // Normalize the scales as expected by Hexagon.
  per_channel_quant->scales_data = weights_quant_params->scale->data;
  std::vector<float> normalized_scales;
  normalized_scales.reserve(per_channel_quant->num_scale_values);
  float scale_max = 0.0;
  for (int i = 0; i < per_channel_quant->num_scale_values; ++i) {
    normalized_scales.push_back(per_channel_quant->scales_data[i]);
    if (per_channel_quant->scales_data[i] > scale_max) {
      scale_max = per_channel_quant->scales_data[i];
    }
  }
  if (scale_max == 0.0) {
    TF_LITE_KERNEL_LOG(context, "Scale max is zero for: %s",
                       weights_tensor.name);
    return kTfLiteError;
  }
  for (int i = 0; i < per_channel_quant->num_scale_values; ++i) {
    normalized_scales[i] =
        std::max(normalized_scales[i] / scale_max, kHexagonMinRelativeScale);
  }
  // Add node for channel scales data.
  const std::vector<int> scales_shape = {1, 1, 1,
                                         per_channel_quant->num_scale_values};
  per_channel_quant->channel_scales_node = graph_builder->AddConstNodeWithData(
      scales_shape.data(), reinterpret_cast<char*>(normalized_scales.data()),
      normalized_scales.size() * sizeof(normalized_scales[0]));
  if (per_channel_quant->splits) {
    // Split channel scales into 32-channel batches.
    const std::vector<int> sliced_scales_shape = {
        1, 1, 1, kDwConv5x5Filt2x2StrideChannelCount};
    for (auto i = 0; i < per_channel_quant->splits; ++i) {
      auto offset = kDwConv5x5Filt2x2StrideChannelCount * i;
      auto* node = graph_builder->AddConstNodeWithData(
          sliced_scales_shape.data(),
          reinterpret_cast<char*>(normalized_scales.data() + offset),
          kDwConv5x5Filt2x2StrideChannelCount * sizeof(normalized_scales[0]));
      per_channel_quant->channel_scales_nodes.push_back(node);
    }
  }
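  // The overall weights range is derived from the largest channel scale:
  // scale_max times the int8 limits [-128, 127] covers every channel's values.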
  *weights_min = -128 * scale_max;
  *weights_max = 127 * scale_max;
  return kTfLiteOk;
}

TfLiteStatus ProcessPerChannelQuantizedBias(
    const TfLiteTensor& data_tensor, const TfLiteTensor& bias_tensor,
    const int bias_tensor_idx, TfLiteContext* context, float* bias_min,
    float* bias_max, GraphBuilder* graph_builder,
    PerChannelQuantData* per_channel_quant,
    std::vector<int>* preprocessed_bias_data, OpBuilder** bias_const_node) {
  const TfLiteAffineQuantization* input_quant_params =
      static_cast<const TfLiteAffineQuantization*>(
          data_tensor.quantization.params);
  const float input_scale = input_quant_params->scale->data[0];
  // Now dequantize bias values to float first, to adjust for the
  // normalization of channel scales.
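  // Each bias value's effective scale is input_scale * per-channel weight
  // scale, matching TFLite's per-channel quantization convention.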
  auto* bias_data = bias_tensor.data.i32;
  const int bias_size = NumElements(&bias_tensor);
  if (bias_size != per_channel_quant->num_scale_values) {
    TF_LITE_KERNEL_LOG(
        context, "Bias/channel scales number mismatch for bias tensor: %s",
        bias_tensor.name);
    return kTfLiteError;
  }
  std::vector<float> dequantized_bias;
  dequantized_bias.reserve(bias_size);
  for (int i = 0; i < bias_size; ++i) {
    const float dequantized_value =
        bias_data[i] * input_scale * per_channel_quant->scales_data[i];
    const float abs_dequantized_value = std::abs(dequantized_value);
    if (abs_dequantized_value > *bias_max) {
      *bias_max = abs_dequantized_value;
    }
    dequantized_bias.push_back(dequantized_value);
  }
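  // Widen the symmetric bias range to 8x the largest |bias| value, presumably
  // to leave headroom in Hexagon's fixed-point bias representation.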
  *bias_max = *bias_max * 8;
  *bias_min = -1 * *bias_max;
  // Now requantize the bias values to the new min/max values.
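  // Each bias becomes a signed 32-bit fixed-point fraction of *bias_max,
  // i.e. round(2^31 * value / *bias_max).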
  preprocessed_bias_data->reserve(per_channel_quant->num_scale_values);
  for (int i = 0; i < bias_size; ++i) {
    preprocessed_bias_data->push_back(static_cast<int>(
        std::round(std::pow(2, 31) * (dequantized_bias[i] / *bias_max))));
  }
  // Add nodes for bias.
  const std::vector<int> bias_shape = {1, 1, 1, bias_size};
  auto* bias_data_node = graph_builder->AddConstNodeWithData(
      bias_shape.data(),
      reinterpret_cast<char*>(preprocessed_bias_data->data()),
      preprocessed_bias_data->size() * sizeof((*preprocessed_bias_data)[0]));
  if (bias_const_node) {
    *bias_const_node = bias_data_node;
  }
  graph_builder->AddTensorWithID(bias_tensor_idx, bias_data_node->GetID(), 0,
                                 /*overwrite=*/true);
  return kTfLiteOk;
}

void Conv2dOpBuilder::CheckShouldSplitDwConv(TfLiteType weights_type,
                                             int input_depth,
                                             bool is_per_channel_quant,
                                             int channel_multiplier) {
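  // Works around an accuracy issue in Hexagon's depthwise conv kernel for 5x5
  // filters with stride >= 2 and per-channel quantized int8 weights: when the
  // depth is a multiple of 32, the op is split into 32-channel slices.
  // See Conv2dOpBuilder.should_split_dwconv_ for details.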
  const TfLiteDepthwiseConvParams* conv_params =
      reinterpret_cast<const TfLiteDepthwiseConvParams*>(builtin_data_);
  int weights_height = weight_shape_[0];
  int weights_width = weight_shape_[1];
  // input_depth * channel_multiplier
  int weights_depth_size = input_depth * weight_shape_[3];
  if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8 &&
      weights_type == kTfLiteInt8 &&
      // weight_shape_ is [fh,fw,din,dmul]
      weights_height == 5 && weights_width == 5 &&
      // Stride of at least 2x2.
      conv_params->stride_height >= 2 && conv_params->stride_width >= 2 &&
      // Depth is greater than 32 and a multiple of 32, so it can be split.
      input_depth > kDwConv5x5Filt2x2StrideChannelCount &&
      input_depth % kDwConv5x5Filt2x2StrideChannelCount == 0 &&
      is_per_channel_quant && channel_multiplier == 1) {
    should_split_dwconv_ = true;
    // Split the inputs into 32-channel batches.
    per_channel_quant_.splits =
        weights_depth_size / kDwConv5x5Filt2x2StrideChannelCount;
    per_channel_quant_.channel_scales_nodes.reserve(per_channel_quant_.splits);
  }
}

void Conv2dOpBuilder::SplitWeightsForDwConv(
    const std::vector<uint8_t>& converted_data, int input_depth,
    int channel_multiplier) {
  int weights_height_size = weight_shape_[0];
  int weights_width_size = weight_shape_[1];
  // Split the weight tensor into 32-channel batches.
  SplitParams split_params{
      .num_split = static_cast<uint16_t>(per_channel_quant_.splits),
      .axis = 2,
  };
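  // Axis 2 is the input-depth dimension of the [fh, fw, din, dmul] layout.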
  std::vector<RuntimeShape> split_shapes;
  std::vector<const tflite::RuntimeShape*> split_shapes_data;
  std::vector<std::vector<uint8_t>> splitted_weights;
  std::vector<uint8_t*> splitted_weights_data;
  split_shapes.reserve(per_channel_quant_.splits);
  split_shapes_data.reserve(per_channel_quant_.splits);
  splitted_weights.reserve(per_channel_quant_.splits);
  splitted_weights_data.reserve(per_channel_quant_.splits);
  for (auto s = 0; s < per_channel_quant_.splits; s++) {
    split_shapes.push_back({weights_height_size, weights_width_size,
                            kDwConv5x5Filt2x2StrideChannelCount,
                            channel_multiplier});
    split_shapes_data.push_back(&split_shapes.back());
    splitted_weights.emplace_back(weights_height_size * weights_width_size *
                                      channel_multiplier *
                                      kDwConv5x5Filt2x2StrideChannelCount,
                                  0);
    splitted_weights_data.push_back(splitted_weights.back().data());
  }
  RuntimeShape weight_shape = {weights_height_size, weights_width_size,
                               input_depth, channel_multiplier};
  optimized_ops::Split(split_params, weight_shape, converted_data.data(),
                       split_shapes_data.data(), splitted_weights_data.data());
  for (auto s = 0; s < per_channel_quant_.splits; s++) {
    auto splitted_weights_node = graph_builder_->AddConstNodeWithData(
        split_shapes[s].DimsData(),
        reinterpret_cast<char*>(splitted_weights_data[s]),
        splitted_weights[s].size() * sizeof(splitted_weights_data[s][0]));
    weights_nodes_.push_back(splitted_weights_node);
  }
}

TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes(
    const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
    TfLiteContext* context, const int input_depth) {
  const std::vector<int> quant_bound_shape = {1, 1, 1, 1};

  const auto& weights_tensor = context->tensors[inputs->data[1]];
  if (weights_tensor.allocation_type != kTfLiteMmapRo) {
    TF_LITE_KERNEL_LOG(
        context, "Weights tensor doesn't have correct allocation type: %s",
        weights_tensor.name);
    return kTfLiteError;
  }
  int weights_batch_size, weights_height_size, weights_width_size,
      weights_depth_size;
  // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC.
  // Transpose NHWC -> HWCN
  GetDims(&weights_batch_size, &weights_height_size, &weights_width_size,
          &weights_depth_size, weights_tensor.dims);

  // Weights tensor could be int8 even for per-tensor quantization.
  // Therefore, we look at the number of scale values to check if it is
  // per-channel quantized.
  TfLiteAffineQuantization* weights_quant_params =
      reinterpret_cast<TfLiteAffineQuantization*>(
          weights_tensor.quantization.params);
  const bool is_per_channel_quant = weights_quant_params->scale->size > 1;

  // WEIGHTS DATA.
  OpBuilder* weights_data_node = nullptr;
  if (op_node_.op_type == OP_Supernode_8x8p32to8) {
    // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC.
    // Transpose NHWC -> HWCN
    weight_shape_ = {weights_height_size, weights_width_size,
                     weights_depth_size, weights_batch_size};
    RuntimeShape nhwc_shape({weights_batch_size, weights_height_size,
                             weights_width_size, weights_depth_size});
    RuntimeShape hwcn_shape({weights_height_size, weights_width_size,
                             weights_depth_size, weights_batch_size});
    std::vector<uint8_t> hwcn(NumElements(&weights_tensor));
    TransposeParams transpose_params;
    transpose_params.perm_count = 4;
    transpose_params.perm[0] = 1;
    transpose_params.perm[1] = 2;
    transpose_params.perm[2] = 3;
    transpose_params.perm[3] = 0;
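    // perm = {1, 2, 3, 0}: output dimension i takes input dimension perm[i],
    // so NHWC becomes HWCN.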
    // TODO(b/151103818): Try merging Transpose & bit flip.
    if (weights_tensor.type == kTfLiteInt8) {
      optimized_ops::Transpose<int8_t>(transpose_params, nhwc_shape,
                                       weights_tensor.data.int8, hwcn_shape,
                                       reinterpret_cast<int8_t*>(hwcn.data()));
      // Flip bits on the weight values so that the int8 values are treated
      // as uint8.
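      // XOR with 0x80 flips the sign bit, adding 128: two's-complement int8
      // maps onto uint8 with the zero point shifted to 128.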
      for (int i = 0; i < hwcn.size(); ++i) {
        hwcn[i] = hwcn[i] ^ k8BitSignFlipConstant;
      }
    } else {
      optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape,
                                        weights_tensor.data.uint8, hwcn_shape,
                                        hwcn.data());
    }
    weights_data_node = graph_builder_->AddConstNodeWithData(
        weight_shape_.data(), reinterpret_cast<char*>(hwcn.data()),
        hwcn.size() * sizeof(hwcn[0]));
  } else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) {
    // Hexagon treats depthwise conv like tf.nn.depthwise_conv2d, where the
    // expected filter shape is [fh,fw,din,dmul].
    // The data itself will remain the same, since TFLite's representation is
    // just a 'flattening' of Hexagon's version.
    const int channel_multiplier = weights_depth_size / input_depth;
    weight_shape_ = {weights_height_size, weights_width_size, input_depth,
                     channel_multiplier};

    // Check if the op hits the Depthwise conv accuracy issue.
    // See Conv2dOpBuilder.should_split_dwconv_ for details.
    CheckShouldSplitDwConv(weights_tensor.type, input_depth,
                           is_per_channel_quant, channel_multiplier);

    if (weights_tensor.type == kTfLiteInt8) {
      // Flip bits on the weight values so that the int8 values are treated
      // as uint8.
      std::vector<uint8_t> converted_data(NumElements(&weights_tensor));
      for (int i = 0; i < converted_data.size(); ++i) {
        converted_data[i] = weights_tensor.data.int8[i] ^ k8BitSignFlipConstant;
      }
      weights_data_node = graph_builder_->AddConstNodeWithData(
          weight_shape_.data(), reinterpret_cast<char*>(converted_data.data()),
          converted_data.size() * sizeof(converted_data[0]));
      if (should_split_dwconv_)
        SplitWeightsForDwConv(converted_data, input_depth, channel_multiplier);
    } else {
      weights_data_node = graph_builder_->AddConstNodeWithData(
          weight_shape_.data(), weights_tensor.data.raw,
          NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0]));
    }
  }
  graph_builder_->AddTensorWithID(inputs->data[1], weights_data_node->GetID(),
                                  0, /*overwrite=*/true);

  // WEIGHTS QUANTIZATION.
  float weights_min = 0;
  float weights_max = 0;
  if (is_per_channel_quant) {
    ProcessPerChannelQuantizedWeights(weights_tensor, context, &weights_min,
                                      &weights_max, graph_builder_,
                                      &per_channel_quant_);
  } else {
    TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
        weights_tensor, &weights_min, &weights_max));
  }
  weights_min_node_ = graph_builder_->AddConstNodeWithData(
      quant_bound_shape.data(), reinterpret_cast<char*>(&weights_min),
      sizeof(weights_min));
  weights_max_node_ = graph_builder_->AddConstNodeWithData(
      quant_bound_shape.data(), reinterpret_cast<char*>(&weights_max),
      sizeof(weights_max));

  return kTfLiteOk;
}

void Conv2dOpBuilder::SplitBiasForDwConv(
    std::vector<int>& preprocessed_bias_data) {
  // Split the bias into 32-channel batches.
  std::vector<int> bias_shape = {1, 1, 1, kDwConv5x5Filt2x2StrideChannelCount};
  for (auto i = 0; i < per_channel_quant_.splits; i++) {
    auto offset = kDwConv5x5Filt2x2StrideChannelCount * i;
    auto* bias_data_node = graph_builder_->AddConstNodeWithData(
        bias_shape.data(),
        reinterpret_cast<char*>(preprocessed_bias_data.data() + offset),
        kDwConv5x5Filt2x2StrideChannelCount *
            sizeof(preprocessed_bias_data[0]));
    bias_nodes_.push_back(bias_data_node);
  }
}

TfLiteStatus Conv2dOpBuilder::InitializeBiasNodes(const TfLiteIntArray* inputs,
                                                  const TfLiteIntArray* outputs,
                                                  TfLiteContext* context) {
  const std::vector<int> quant_bound_shape = {1, 1, 1, 1};

  const auto& bias_tensor = context->tensors[inputs->data[2]];

  float bias_min = 0;
  float bias_max = 0;
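  // channel_scales_node is only set by ProcessPerChannelQuantizedWeights, so a
  // non-null value means the weights are per-channel quantized.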
  if (per_channel_quant_.channel_scales_node != nullptr) {
    std::vector<int> preprocessed_bias_data;
    ProcessPerChannelQuantizedBias(
        context->tensors[inputs->data[0]], bias_tensor, inputs->data[2],
        context, &bias_min, &bias_max, graph_builder_, &per_channel_quant_,
        &preprocessed_bias_data);
    if (should_split_dwconv_) SplitBiasForDwConv(preprocessed_bias_data);
  } else {
    auto* bias_data_node =
        graph_builder_->AddConstNodeWithData(inputs->data[2], bias_tensor);
    graph_builder_->AddTensorWithID(inputs->data[2], bias_data_node->GetID(), 0,
                                    /*overwrite=*/true);
    TF_LITE_ENSURE_STATUS(
        ComputeMinAndMaxQuantValues(bias_tensor, &bias_min, &bias_max));
  }

  bias_min_node_ = graph_builder_->AddConstNodeWithData(
      quant_bound_shape.data(), reinterpret_cast<char*>(&bias_min),
      sizeof(bias_min));
  bias_max_node_ = graph_builder_->AddConstNodeWithData(
      quant_bound_shape.data(), reinterpret_cast<char*>(&bias_max),
      sizeof(bias_max));

  return kTfLiteOk;
}

}  // namespace hexagon
}  // namespace delegates
}  // namespace tflite