xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/graph_transformations/graph_transformations.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
16 #define TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
17 
#include <cstddef>
#include <initializer_list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
23 
24 #include "tensorflow/lite/toco/model.h"
25 #include "tensorflow/lite/toco/toco_port.h"
26 
27 namespace toco {
28 
29 class GraphTransformation {
30  public:
31   virtual ::tensorflow::Status Run(Model* model, std::size_t op_index,
32                                    bool* modified) = 0;
33   virtual const char* Name() const = 0;
~GraphTransformation()34   virtual ~GraphTransformation() {}
35   // Returns the list of messages that this graph transformation
36   // generated since ClearMessages() was called.
Messages()37   const std::vector<std::string>& Messages() const { return messages_; }
38   // Clears the list of messages; should be called after every
39   // run of this graph transformation.
ClearMessages()40   void ClearMessages() { return messages_.clear(); }
41   // Adds a message; normally only called by the graph transformation
42   // itself during its run (this function could be protected).
43   template <typename... Args>
AddMessageF(const char * format,const Args &...args)44   void AddMessageF(const char* format, const Args&... args) {
45     return messages_.push_back(toco::port::StringF(format, args...));
46   }
47 
48  protected:
GraphTransformation()49   GraphTransformation() {}
50 
51   // List of messages generated by this graph transformation.
52   std::vector<std::string> messages_;
53 
54  private:
55   GraphTransformation(const GraphTransformation& other) = delete;
56   GraphTransformation(const GraphTransformation&& other) = delete;
57 };
58 
59 class GraphTransformationsSet {
60  public:
61   // The choice of a container with fully-specified iteration order
62   // ensures that graph transformations are always run in the same order,
63   // which avoids having toco randomly fail or produce different results
64   // depending on the toolchain. Ideally success/results should be independent
65   // of the order in which graph transformations are run, but that's
66   // unfortunately not currently guaranteed to be the case.
67   using TransformationsContainer =
68       std::vector<std::unique_ptr<GraphTransformation>>;
69 
GraphTransformationsSet()70   GraphTransformationsSet() {}
GraphTransformationsSet(const std::initializer_list<GraphTransformation * > transformations)71   GraphTransformationsSet(
72       const std::initializer_list<GraphTransformation*> transformations) {
73     for (GraphTransformation* t : transformations) {
74       Add(t);
75     }
76   }
Add(GraphTransformation * transformation)77   void Add(GraphTransformation* transformation) {
78     const std::string& name = transformation->Name();
79     CHECK(!names_.count(name));
80     names_.insert(name);
81     transformations_.emplace_back(transformation);
82   }
begin()83   TransformationsContainer::const_iterator begin() const {
84     return transformations_.begin();
85   }
end()86   TransformationsContainer::const_iterator end() const {
87     return transformations_.end();
88   }
empty()89   bool empty() const { return transformations_.empty(); }
90 
91  private:
92   GraphTransformationsSet(const GraphTransformationsSet& other) = delete;
93   GraphTransformationsSet(const GraphTransformationsSet&& other) = delete;
94   std::vector<std::unique_ptr<GraphTransformation>> transformations_;
95   // Names of transformations in the set. Only used to guard against dupes.
96   std::unordered_set<std::string> names_;
97 };
98 
99 // Run the given list of graph transformations on the model.
100 // The message is only for logging purposes.
101 // The transformations is a rvalue reference, indicating that
102 // nothing else will use these pointers. The user is supposed to
103 // construct GraphTransformation objects by using 'new', pass us
104 // the resulting raw pointers, and this RunGraphTransformations
105 // takes care of delete'ing these pointers.
106 tensorflow::Status RunGraphTransformationsWithStatus(
107     Model* model, const std::string& msg,
108     const GraphTransformationsSet& transformations);
109 
RunGraphTransformations(Model * model,const std::string & msg,const GraphTransformationsSet & transformations)110 inline void RunGraphTransformations(
111     Model* model, const std::string& msg,
112     const GraphTransformationsSet& transformations) {
113   auto s = RunGraphTransformationsWithStatus(model, msg, transformations);
114   CHECK(s.ok()) << s.error_message();
115 }
116 
// Declares a GraphTransformation subclass named `GTName` whose Name() returns
// the string "GTName". Run() is only declared here; its definition lives
// elsewhere (presumably in a per-transformation .cc file — confirm). The
// expansion already ends with `};`, so invocations take no trailing semicolon.
#define DECLARE_GRAPH_TRANSFORMATION(GTName)                     \
  class GTName : public GraphTransformation {                    \
   public:                                                       \
    const char* Name() const override { return #GTName; }        \
    ::tensorflow::Status Run(Model* model, std::size_t op_index, \
                             bool* modified) override;           \
  };
124 
125 // List of all graph transformations
126 DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2OrV3ToV1)127 DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2OrV3ToV1)
128 DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2OrV3ToV1)
129 DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
130 DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
131 DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape)
132 DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd)
133 DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialPackToReshape)
134 DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat)
135 DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape)
136 DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
137 DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
138 DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
139 DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine)
140 DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary)
141 DECLARE_GRAPH_TRANSFORMATION(GroupBidirectionalSequenceLstm)
142 DECLARE_GRAPH_TRANSFORMATION(GroupBidirectionalSequenceRnn)
143 DECLARE_GRAPH_TRANSFORMATION(GroupDynamicBidirectionalSequenceLstm)
144 DECLARE_GRAPH_TRANSFORMATION(GroupDynamicBidirectionalSequenceRnn)
145 DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization)
146 DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
147 DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
148 DECLARE_GRAPH_TRANSFORMATION(IdentifyHardSwish)
149 DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
150 DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
151 DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
152 DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
153 DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
154 DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
155 DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
156 DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
157 DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
158 DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits)
159 DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
160 DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
161 DECLARE_GRAPH_TRANSFORMATION(Quantize)
162 DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
163 DECLARE_GRAPH_TRANSFORMATION(RemoveSuccessiveTranspose)
164 DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
165 DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
166 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
167 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
168 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
169 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialFakeQuant)
170 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialSlice)
171 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
172 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedMinMax)
173 DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp)
174 DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization)
175 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator)
176 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
177 DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
178 DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
179 DECLARE_GRAPH_TRANSFORMATION(ReadArrayMinmaxAndNarrowRangeFromFakeQuant)
180 DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary)
181 DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose)
182 DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
183 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
184 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
185 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
186 DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
187 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
188 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
189 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
190 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
191 DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
192 DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
193 DECLARE_GRAPH_TRANSFORMATION(UnrollBatchMatMul)
194 DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes)
195 DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes)
196 DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
197 DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes)
198 DECLARE_GRAPH_TRANSFORMATION(ResolveReduceAttributes)
199 DECLARE_GRAPH_TRANSFORMATION(ResolveReshapeAttributes)
200 DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
201 DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
202 DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
203 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack)
204 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
205 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
206 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
207 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice)
208 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
209 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
210 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather)
211 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSelect)
212 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTile)
213 DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero)
214 DECLARE_GRAPH_TRANSFORMATION(Dequantize)
215 DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup)
216 DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights)
217 DECLARE_GRAPH_TRANSFORMATION(ResolveFakeQuantArgsFromVars)
218 DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes)
219 DECLARE_GRAPH_TRANSFORMATION(IdentifyNearestUpsample)
220 
221 class PropagateDefaultMinMax : public GraphTransformation {
222  public:
223   ::tensorflow::Status Run(Model* model, std::size_t op_index,
224                            bool* modified) override;
225   const char* Name() const override { return "PropagateDefaultMinMax"; }
226 
227   bool has_any_ranges_defined() const { return !type_ranges_.empty(); }
228   void DefineTypeRange(ArrayDataType data_type, double min, double max) {
229     MinMax minmax;
230     minmax.min = min;
231     minmax.max = max;
232     type_ranges_.emplace_back(data_type, minmax);
233   }
234 
235  private:
236   bool SetArrayMinMax(const std::string& array_name, Array* array);
237   std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_;
238 };
239 
240 class RemoveTrivialReshape : public GraphTransformation {
241  public:
242   ::tensorflow::Status Run(Model* model, std::size_t op_index,
243                            bool* modified) override;
Name()244   const char* Name() const override { return "RemoveTrivialReshape"; }
treat_expand_dims_as_trivial()245   bool treat_expand_dims_as_trivial() const {
246     return treat_expand_dims_as_trivial_;
247   }
set_treat_expand_dims_as_trivial(bool val)248   void set_treat_expand_dims_as_trivial(bool val) {
249     treat_expand_dims_as_trivial_ = val;
250   }
251 
252  private:
253   bool treat_expand_dims_as_trivial_ = false;
254 };
255 
256 class ResolveConstantFakeQuant : public GraphTransformation {
257  public:
258   ::tensorflow::Status Run(Model* model, std::size_t op_index,
259                            bool* modified) override;
Name()260   const char* Name() const override { return "ResolveConstantFakeQuant"; }
261 
262   // True if the num_bits should adjust the final data type.
propagate_fake_quant_num_bits()263   bool propagate_fake_quant_num_bits() const {
264     return propagate_fake_quant_num_bits_;
265   }
set_propagate_fake_quant_num_bits(bool val)266   void set_propagate_fake_quant_num_bits(bool val) {
267     propagate_fake_quant_num_bits_ = val;
268   }
269 
270  private:
271   bool propagate_fake_quant_num_bits_ = false;
272 };
273 
274 class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
275  public:
276   ::tensorflow::Status Run(Model* model, std::size_t op_index,
277                            bool* modified) override;
Name()278   const char* Name() const override {
279     return "EnsureUint8WeightsSafeForFastInt8Kernels";
280   }
allow_nudging_weights()281   bool allow_nudging_weights() const { return allow_nudging_weights_; }
set_allow_nudging_weights(bool val)282   void set_allow_nudging_weights(bool val) { allow_nudging_weights_ = val; }
283 
has_default_ranges_flag()284   bool has_default_ranges_flag() const { return has_default_ranges_flag_; }
set_has_default_ranges_flag(bool val)285   void set_has_default_ranges_flag(bool val) { has_default_ranges_flag_ = val; }
286 
287  private:
288   bool allow_nudging_weights_ = false;
289   bool has_default_ranges_flag_ = false;
290 };
291 
292 class IdentifyDilatedConv : public GraphTransformation {
293  public:
294   ::tensorflow::Status Run(Model* model, std::size_t op_index,
295                            bool* modified) override;
Name()296   const char* Name() const override { return "IdentifyDilatedConv"; }
identify_depthwise_conv()297   bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
set_identify_depthwise_conv(bool val)298   void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
299 
300  private:
301   bool identify_depthwise_conv_ = true;
302 };
303 
304 #undef DECLARE_GRAPH_TRANSFORMATION
305 
306 }  // end namespace toco
307 
308 #endif  // TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
309