// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/constant_folding.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
18 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 
// NOTE(review): these appear to be name prefixes/markers for nodes created by
// the constant-folding pass — confirm against constant_folding.cc.
// `inline constexpr` (C++17) gives a single program-wide definition instead of
// one internal-linkage copy per translation unit that includes this header.
inline constexpr char kConstantFoldingConst[] = "ConstantFolding";
inline constexpr char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
// Defined in constant_folding.cc.
extern const int64_t kMaxConstantSize;
37 
38 // Constant folding optimization for a graph.
39 class ConstantFolding : public GraphOptimizer {
40  public:
41   // The size limit will only be considered if the newly created node is greater
42   // than original_size (optional).
43   static Status CreateNodeDef(const string& name, const TensorValue& tensor,
44                               NodeDef* node, size_t original_size = 0);
45   static string AddControlDependency(const string& input_name, GraphDef* graph,
46                                      NodeMap* node_map);
47 
48   explicit ConstantFolding(DeviceBase* cpu_device,
49                            bool disable_compressed_tensor_optimization = false,
50                            bool fold_quantization_emulation = true);
51   ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device,
52                   bool disable_compressed_tensor_optimization = false,
53                   bool fold_quantization_emulation = true);
54 
~ConstantFolding()55   ~ConstantFolding() override {}
56 
name()57   string name() const override { return "constant_folding"; };
58 
UsesFunctionLibrary()59   bool UsesFunctionLibrary() const override { return false; }
60 
61   Status Optimize(Cluster* cluster, const GrapplerItem& item,
62                   GraphDef* output) override;
63 
64  private:
65   bool ForwardInputs(NodeDef* node, absl::Span<const int> inputs_to_forward);
66   string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
67   bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const;
68 
69   bool IsReallyConstant(const NodeDef& node) const;
70 
71   bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor);
72 
73   Status MaterializeShapes(const GraphProperties& properties);
74 
75   Status MaterializeBroadcastGradientArgs(const NodeDef& node,
76                                           const GraphProperties& properties);
77   Status MaterializeReductionIndices(NodeDef* node,
78                                      const GraphProperties& properties);
79   Status MaterializeConstantValuedNode(NodeDef* node,
80                                        const GraphProperties& properties);
81   Status MaterializeOutputValues(NodeDef* node,
82                                  const GraphProperties& properties);
83   Status MaterializeConstants(const GraphProperties& properties);
84 
85   bool IsFoldable(const NodeDef& node, const GraphProperties* properties);
86   bool IsFoldableUncached(const NodeDef& node,
87                           const GraphProperties* properties) const;
88   bool MaybeFoldable(const NodeDef& node,
89                      const GraphProperties* properties) const;
90 
91   Status EvaluateNode(const NodeDef& node,
92                       const gtl::InlinedVector<TensorValue, 4>& inputs,
93                       gtl::InlinedVector<TensorValue, 4>* output) const;
94 
95   Status EvaluateOneFoldable(const NodeDef& node, std::vector<NodeDef>* outputs,
96                              bool* result_too_large);
97 
98   Status FoldMergeNode(NodeDef* node, GraphDef* output_graph);
99   Status FoldNode(NodeDef* node, GraphDef* output_graph,
100                   bool* result_too_large);
101 
102   bool IsOnes(const NodeDef& node) const;
103   bool IsZeros(const NodeDef& node) const;
104   bool ReplaceOperationWithBroadcastTo(int input_to_broadcast,
105                                        const GraphProperties& properties,
106                                        NodeDef* node, GraphDef* graph);
107   void ReplaceOperationWithIdentity(int input_to_forward,
108                                     const GraphProperties& properties,
109                                     NodeDef* node, GraphDef* graph);
110   void ReplaceOperationWithSnapshot(int input_to_forward,
111                                     const GraphProperties& properties,
112                                     NodeDef* node, GraphDef* graph);
113   void ReplaceOperationWithNoOp(NodeDef* node, GraphProperties* properties,
114                                 GraphDef* graph);
115   void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
116                                              const GraphProperties& properties,
117                                              NodeDef* node, GraphDef* graph);
118   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
119   Status ReplaceOperationWithConstant(double value,
120                                       const GraphProperties& properties,
121                                       const TensorShapeProto& shape,
122                                       NodeDef* node, GraphDef* graph);
123 
124   // Notice: Destroys *value.
125   Status ReplaceOperationWithConstantTensor(DataType dtype, TensorProto* value,
126                                             NodeDef* node, GraphDef* graph);
127 
128   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
129   Status FoldGraph(const GraphProperties& properties, GraphDef* output,
130                    absl::flat_hash_set<string>* nodes_to_not_simplify);
131 
132   Status IsSimplifiableReshape(const NodeDef& node,
133                                const GraphProperties& properties) const;
134   Status SimplifyGraph(GraphDef* optimized_graph, GraphProperties* properties,
135                        absl::flat_hash_set<string>* nodes_to_not_simplify);
136   Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
137                       GraphProperties* properties);
138 
139   Status RunOptimizationPass(Cluster* cluster, GrapplerItem* item,
140                              GraphProperties* properties,
141                              GraphDef* optimized_graph);
142 
143   // Applies partial constant folding for Concat which is not commutative.
144   // Returns true if the transformation applied successfully.
145   bool PartialConcatConstFolding(GraphDef* optimized_graph,
146                                  GraphProperties* properties, NodeDef* node);
147 
148   // Applies partial constant folding for associative operators AddN and
149   // AccumulateNV2. Returns true if the transformation applied successfully.
150   bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
151                                   GraphProperties* properties, NodeDef* node);
152 
153   // Applies partial constant propagation through IdentityN operator.
154   // Returns true if the transformation applied successfully.
155   bool PartialConstPropThroughIdentityN(NodeDef* node);
156 
157   struct ConstantPushDownContext {
158     NodeDef* op_child;
159     NodeDef* const_child;
160     bool left_child_is_const;
161     bool right_child_is_const;
162     NodeDef* left_leaf;
163     NodeDef* right_leaf;
164     bool left_leaf_is_const;
165     bool right_leaf_is_const;
166 
167     // Shape & type information.
168     const std::vector<OpInfo::TensorProperties>* parent_input_props;
169     const std::vector<OpInfo::TensorProperties>* op_child_input_props;
170   };
171 
172   // Populates ctx with pointers to the nodes in expression tree for which
173   // constant pushdown optimization is being considered, corresponding to one of
174   // the following configurations:
175   //
176   //               parent                            parent
177   //               /    \                            /    \
178   //        op_child   const_child            const_child op_child
179   //         /     \                                       /     \
180   //    left_leaf  right_leaf                        left_leaf  right_leaf
181   //
182   // Returns true if the expression is possible amenable for optimization.
183   // Returns false if must_have_properties is true and input properties for
184   // parent and op_child are not known.
185   bool PrepareConstantPushDown(const NodeDef& parent,
186                                const GraphProperties& properties,
187                                bool must_have_properties,
188                                ConstantPushDownContext* ctx) const;
189 
190   // Pushes down constants on '+', '-', '*', and '/' operators if applicable.
191   // Returns true if the transformation applied successfully.
192   bool ConstantPushDown(GraphProperties* properties, GraphDef* optimized_graph,
193                         NodeDef* node);
194 
195   // Pushes down constants on '+' and 'BiasAdd' operators if applicable.
196   // Returns true if the graph was modified.
197   bool ConstantPushDownBiasAdd(GraphProperties* properties,
198                                GraphDef* optimized_graph, NodeDef* node);
199 
200   // Aggregate constants present around a conv operator. Returns true if the
201   // transformation was applied successfully.
202   bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
203                        const GraphProperties& properties);
204 
205   // Strength reduces floating point division by a constant Div(x, const) to
206   // multiplication by the reciprocal Mul(x, Reciprocal(const)).
207   bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node);
208 
209   // Simplifies arithmetic operations with ones or zeros. Returns the status,
210   // and updates the success input argument that denotes if any simplification
211   // was applied.
212   Status SimplifyArithmeticOperations(const GraphProperties& properties,
213                                       bool use_shape_info,
214                                       GraphDef* optimized_graph, NodeDef* node);
215 
216   // Simplifies a Reshape operation to an Identity operation if applicable.
217   bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
218                        NodeDef* node);
219 
220   // Returns true iff the node is a reduction and its reduction indices are
221   // constant. Sets *indices_is_empty to true if the set of dimensions to reduce
222   // along is empty (this happens often in the gradient graphs).
223   bool IsReductionWithConstantIndices(const NodeDef& node,
224                                       bool* indices_is_empty) const;
225   // Returns true if theres a possibility that a Reduce node could be simplified
226   // to an Identity/Reshape.
227   bool IsReductionCandidateForSimplification(
228       const NodeDef& node, const GraphProperties& properties,
229       TensorShapeProto* input_tensor_shape,
230       TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const;
231   // Returns true iff this reduction can be reduced to an identity (i.e if the
232   // input dimensions to reduce along are all of size 1 and keep_dims is true).
233   bool IsReductionSimplifiableToIdentity(
234       const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
235       const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
236   // Changes a reduction into an Identity op, returning true on success.
237   bool ReplaceReductionWithIdentity(NodeDef* node) const;
238 
239   // Simplifies a Reduction operation to an Identity/Reshape operation if
240   // applicable.
241   bool SimplifyReduction(GraphDef* optimized_graph,
242                          const GraphProperties& properties, NodeDef* node);
243 
244   // Switch(x, x) will always feed false to its false branch and true to
245   // its true branch. By rewriting the graph a bit, we can propagate these
246   // constants down the two output branches, and just use control dependencies
247   // to trigger the selected one at runtime. For example,
248   //
249   //     +------+
250   // x-->|Switch|-->a  (in practice there may be multiple consumers of each
251   // x-->|      |-->b   output branch.)
252   //     +------+
253   //
254   // Is rewritten as
255   //
256   //     +------+
257   // x-->|Switch|-->Identity--^>Const(false)-->a
258   // x-->|      |-->Identity--^>Const(true)-->b
259   //     +------+
260   bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node);
261 
262   // Moves constants past Enter node if applicable.
263   bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node);
264 
265   // Simplifies Pack operation if applicable.
266   bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
267 
268   // Simplifies a Squeeze operation to an Identity operation if applicable.
269   void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
270                        GraphDef* optimized_graph, NodeDef* node);
271 
272   // Simplifies a Pad operation to an Identity operation if applicable.
273   Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
274                      GraphDef* optimized_graph, NodeDef* node);
275 
276   // Simplifies a Tile operation to an Identity operation if applicable.
277   Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
278                       GraphDef* optimized_graph, NodeDef* node);
279 
280   // Simplifies a StridedSlice operation to an Identity operation if applicable.
281   Status SimplifyStridedSlice(const GraphProperties& properties,
282                               bool use_shape_info, GraphDef* optimized_graph,
283                               NodeDef* node);
284 
285   // Simplifies a Slice operation to an Identity operation if applicable.
286   Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
287                        GraphDef* optimized_graph, NodeDef* node);
288 
289   // Simplify a Case operation where the output_idx is known.
290   bool SimplifyCase(GraphDef* optimized_graph, NodeDef* node);
291 
292   // Simplify a Select operation where the predicates are all true or all false.
293   bool SimplifySelect(const GraphProperties& properties,
294                       GraphDef* optimized_graph, NodeDef* node);
295 
296   // Replaces variable updates that are effectively no-ops with NoOp nodes.
297   void RemoveRedundantVariableUpdates(GraphProperties* properties,
298                                       GraphDef* optimized_graph, NodeDef* node);
299 
300   // Removes Reverse op over dimensions with size 1.
301   Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
302                        GraphDef* optimized_graph, NodeDef* node);
303 
304   // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
305   void RemoveRandomShuffle(const GraphProperties& properties,
306                            bool use_shape_info, GraphDef* optimized_graph,
307                            NodeDef* node);
308 
309   // Removes Shuffle or Transpose op over dimensions of size 1.
310   Status RemoveShuffleOrTranspose(const GraphProperties& properties,
311                                   bool use_shape_info,
312                                   GraphDef* optimized_graph, NodeDef* node);
313 
314   // Removes Split or SplitV node if possible.
315   void RemoveSplitOrSplitV(const GraphProperties& properties,
316                            GraphDef* optimized_graph, NodeDef* node);
317 
318   bool GetConcatAxis(const NodeDef& node, int* axis);
319   bool MergeConcat(bool use_shape_info, GraphProperties* properties,
320                    GraphDef* optimized_graph, NodeDef* node);
321 
322   Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
323                                                GraphDef* optimized_graph);
324 
325   // Points to an externally provided device or to owned_device_;
326   RewriterConfig::Toggle opt_level_;
327   DeviceBase* cpu_device_;
328   std::unique_ptr<DeviceBase> owned_device_;
329 
330   std::unique_ptr<ResourceMgr> resource_mgr_;
331   GraphDef* graph_;
332   std::unique_ptr<NodeMap> node_map_;
333   std::unordered_set<string> nodes_to_preserve_;
334   // TODO(rmlarsen): Could these be keyed on absl::string_view?
335   absl::flat_hash_set<string> nodes_allowlist_;
336   absl::flat_hash_set<string> feed_nodes_;
337   absl::flat_hash_map<string, bool> maybe_foldable_nodes_;
338   bool has_fetch_;
339   bool graph_modified_;
340   bool graph_contains_assign_or_inplace_op_;
341   bool disable_compressed_tensor_optimization_;
342   bool fold_quantization_emulation_;
343 };
344 
345 }  // end namespace grappler
346 }  // end namespace tensorflow
347 
348 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
349