/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_HEXAGON_BUILDERS_CONV_2D_BUILDER_H_
#define TENSORFLOW_LITE_DELEGATES_HEXAGON_BUILDERS_CONV_2D_BUILDER_H_

#include <vector>

#include "tensorflow/lite/delegates/hexagon/builders/op_builder.h"

namespace tflite {
namespace delegates {
namespace hexagon {

// Stores quantization data for Conv/TransposeConv nodes.
// This information is used to handle per-channel quantized weights & biases
// correctly in the Hexagon delegate.
struct PerChannelQuantData {
  // This is initialized while processing quantized weights, and acts as an
  // input to Hexagon Conv nodes.
  OpBuilder* channel_scales_node = nullptr;
  // Scale information is obtained from the TfLiteAffineQuantization params of
  // the weights tensor.
  float* scales_data = nullptr;
  int num_scale_values = 1;
  // Number of splits used to work around the DepthwiseConv accuracy issue.
  // See Conv2dOpBuilder.should_split_dwconv_ for details.
  int splits = 0;
  std::vector<OpBuilder*> channel_scales_nodes;
};
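
// A minimal sketch (not part of this header) of how `scales_data` and
// `num_scale_values` are typically populated from a per-channel quantized
// weights tensor's TfLiteAffineQuantization params:
//
//   const auto* quant_params = reinterpret_cast<TfLiteAffineQuantization*>(
//       weights_tensor.quantization.params);
//   PerChannelQuantData per_channel_quant;
//   per_channel_quant.scales_data = quant_params->scale->data;
//   per_channel_quant.num_scale_values = quant_params->scale->size;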

class Conv2dOpBuilder : public OpBuilder {
 public:
  explicit Conv2dOpBuilder(GraphBuilder* graph_builder, int op_type)
      : OpBuilder(graph_builder, op_type) {}
  TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
                                const TfLiteIntArray* outputs,
                                TfLiteContext* context) override;

  TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
                               TfLiteContext* context) override;

  ~Conv2dOpBuilder() override;

 private:
  TfLiteStatus InitializeWeightsNodes(const TfLiteIntArray* inputs,
                                      const TfLiteIntArray* outputs,
                                      TfLiteContext* context,
                                      const int input_depth);

  TfLiteStatus InitializeBiasNodes(const TfLiteIntArray* inputs,
                                   const TfLiteIntArray* outputs,
                                   TfLiteContext* context);

  void BuildStandardConv(const TfLiteIntArray* inputs,
                         const TfLiteTensor& output_data_tensor,
                         OpBuilder* data_min_const, OpBuilder* data_max_const,
                         OpBuilder* conv_output_min_const,
                         OpBuilder* conv_output_max_const,
                         OpBuilder* stride_node,
                         const TfLitePadding padding_type,
                         TensorID* output_tensor, TensorID* output_min_tensor,
                         TensorID* output_max_tensor);
  void BuildDilatedDwConv(const TfLiteIntArray* inputs,
                          const TfLiteTensor& data_tensor,
                          const TfLiteTensor& output_data_tensor,
                          OpBuilder* data_min_const, OpBuilder* data_max_const,
                          OpBuilder* conv_output_min_const,
                          OpBuilder* conv_output_max_const,
                          OpBuilder* stride_node, int stride_height,
                          const TfLitePadding padding_type,
                          TensorID* output_tensor, TensorID* output_min_tensor,
                          TensorID* output_max_tensor);
  void BuildSplittedDwConv(
      const TfLiteIntArray* inputs, const TfLiteTensor& data_tensor,
      const TfLiteTensor& output_data_tensor, OpBuilder* data_min_const,
      OpBuilder* data_max_const, OpBuilder* conv_output_min_const,
      OpBuilder* conv_output_max_const, OpBuilder* stride_node,
      const TfLitePadding padding_type, TensorID* output_tensor,
      TensorID* output_min_tensor, TensorID* output_max_tensor);

  TensorID node_output_;
  std::vector<float> transposed_weights_;
  std::vector<int> stride_shape_;
  std::vector<int> weight_shape_;
  OpBuilder* weights_min_node_ = nullptr;
  OpBuilder* weights_max_node_ = nullptr;
  OpBuilder* bias_min_node_ = nullptr;
  OpBuilder* bias_max_node_ = nullptr;

  // TODO(b/228874753)
  // We are seeing accuracy issues with DepthwiseSupernode_8x8p32to8 when all
  // of the following hold:
  // * the kernel size is 5x5
  // * the stride size is 2x2
  // * the weights are per-channel quantized
  // * the input depth is more than 32
  //
  // To work around the issue, the DepthwiseSupernode_8x8p32to8 op is split
  // into 32-channel batches: the input tensor, weights, bias, and channel
  // scales are each split into 32-channel slices and fed to multiple
  // DepthwiseSupernode_8x8p32to8 ops, and the results are stitched back
  // together with a Concat op. (A sketch of the split arithmetic follows
  // CheckShouldSplitDwConv below.)
  //
  // Checks whether this node hits the DepthwiseSupernode_8x8p32to8 accuracy
  // issue described above.
  void CheckShouldSplitDwConv(TfLiteType weights_type, int input_depth,
                              bool is_per_channel_quant,
                              int channel_multiplier);
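
  // A minimal sketch of the split arithmetic described above; the constant
  // and helper names are assumptions, not members of this class. For example,
  // an input depth of 64 yields (64 + 31) / 32 = 2 channel batches.
  //
  //   constexpr int kChannelBatchSize = 32;
  //   int NumDwConvSplits(int input_depth) {
  //     return (input_depth + kChannelBatchSize - 1) / kChannelBatchSize;
  //   }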
  // Splits the weights into multiple 32-channel nodes.
  // `converted_data` holds the MSB-flipped int8 weight values.
  void SplitWeightsForDwConv(const std::vector<uint8_t>& converted_data,
                             int input_depth, int channel_multiplier);
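
  // A minimal sketch of the "MSB flip" mentioned above (`FlipMsb` is a
  // hypothetical helper, not part of this class): XOR-ing the sign bit maps
  // an int8 value to a uint8 encoding whose zero point is shifted by 128,
  // e.g. int8 -128 -> uint8 0 and int8 127 -> uint8 255.
  //
  //   uint8_t FlipMsb(int8_t v) {
  //     return static_cast<uint8_t>(v) ^ 0x80;
  //   }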
  // Splits the bias into 32-element batches.
  // `preprocessed_bias_data` is the output of ProcessPerChannelQuantizedBias.
  void SplitBiasForDwConv(std::vector<int>& preprocessed_bias_data);
  bool should_split_dwconv_ = false;
  std::vector<TensorID> data_nodes_;
  std::vector<OpBuilder*> bias_nodes_;
  std::vector<OpBuilder*> weights_nodes_;

  // Modified only if the node has per-channel quantized weights/biases.
  PerChannelQuantData per_channel_quant_;

  // Only used for dilated Depthwise Conv; see the sketch after these members.
  std::vector<int> dilation_factors_h_w_;
  std::vector<int> space_to_batch_paddings_;
  std::vector<int> batch_to_space_crops_;
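
  // A sketch of the decomposition these members support, assuming the
  // standard SpaceToBatch lowering of dilated convolution (the exact padding
  // and crop math lives in the implementation):
  //
  //   SpaceToBatchND(input, dilation_factors_h_w_, space_to_batch_paddings_)
  //     -> DepthwiseConv (with effective dilation of 1)
  //     -> BatchToSpaceND(result, dilation_factors_h_w_, batch_to_space_crops_)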
};

// ProcessPerChannelQuantizedWeights & ProcessPerChannelQuantizedBias can be
// used to pre-process per-channel quantized weights & biases for Hexagon.
// NOTE: ProcessPerChannelQuantizedWeights should be run before
// ProcessPerChannelQuantizedBias. This is because we set PerChannelQuantData
// based on the weights tensor, and that data is then used while preprocessing
// the bias.

TfLiteStatus ProcessPerChannelQuantizedWeights(
    const TfLiteTensor& weights_tensor, TfLiteContext* context,
    float* weights_min, float* weights_max, GraphBuilder* graph_builder,
    PerChannelQuantData* per_channel_quant);

TfLiteStatus ProcessPerChannelQuantizedBias(
    const TfLiteTensor& data_tensor, const TfLiteTensor& bias_tensor,
    const int bias_tensor_idx, TfLiteContext* context, float* bias_min,
    float* bias_max, GraphBuilder* graph_builder,
    PerChannelQuantData* per_channel_quant,
    std::vector<int>* preprocessed_bias_data,
    OpBuilder** bias_const_node = nullptr);
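
// A minimal usage sketch of the required ordering (variable names such as
// `weights_tensor` and `graph_builder` are assumptions for illustration):
//
//   PerChannelQuantData per_channel_quant;
//   float weights_min, weights_max, bias_min, bias_max;
//   std::vector<int> preprocessed_bias_data;
//   // Weights first: this fills `per_channel_quant`...
//   TF_LITE_ENSURE_STATUS(ProcessPerChannelQuantizedWeights(
//       weights_tensor, context, &weights_min, &weights_max, graph_builder,
//       &per_channel_quant));
//   // ...which the bias preprocessing below then relies on.
//   TF_LITE_ENSURE_STATUS(ProcessPerChannelQuantizedBias(
//       data_tensor, bias_tensor, bias_tensor_idx, context, &bias_min,
//       &bias_max, graph_builder, &per_channel_quant,
//       &preprocessed_bias_data));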

}  // namespace hexagon
}  // namespace delegates
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_HEXAGON_BUILDERS_CONV_2D_BUILDER_H_