/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

const char kConstantFoldingConst[] = "ConstantFolding";
const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
extern const int64_t kMaxConstantSize;

// Constant folding optimization for a graph.
class ConstantFolding : public GraphOptimizer {
 public:
  // The size limit will only be considered if the newly created node is
  // greater than original_size (optional).
  static Status CreateNodeDef(const string& name, const TensorValue& tensor,
                              NodeDef* node, size_t original_size = 0);
  static string AddControlDependency(const string& input_name, GraphDef* graph,
                                     NodeMap* node_map);

  explicit ConstantFolding(DeviceBase* cpu_device,
                           bool disable_compressed_tensor_optimization = false,
                           bool fold_quantization_emulation = true);
  ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device,
                  bool disable_compressed_tensor_optimization = false,
                  bool fold_quantization_emulation = true);
  ~ConstantFolding() override {}

  string name() const override { return "constant_folding"; };

  bool UsesFunctionLibrary() const override { return false; }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* output) override;

 private:
  bool ForwardInputs(NodeDef* node, absl::Span<const int> inputs_to_forward);
  string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
  bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const;

  bool IsReallyConstant(const NodeDef& node) const;

  bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor);

  Status MaterializeShapes(const GraphProperties& properties);

  Status MaterializeBroadcastGradientArgs(const NodeDef& node,
                                          const GraphProperties& properties);
  Status MaterializeReductionIndices(NodeDef* node,
                                     const GraphProperties& properties);
  Status MaterializeConstantValuedNode(NodeDef* node,
                                       const GraphProperties& properties);
  Status MaterializeOutputValues(NodeDef* node,
                                 const GraphProperties& properties);
  Status MaterializeConstants(const GraphProperties& properties);

  bool IsFoldable(const NodeDef& node, const GraphProperties* properties);
  bool IsFoldableUncached(const NodeDef& node,
                          const GraphProperties* properties) const;
  bool MaybeFoldable(const NodeDef& node,
                     const GraphProperties* properties) const;

  Status EvaluateNode(const NodeDef& node,
                      const gtl::InlinedVector<TensorValue, 4>& inputs,
                      gtl::InlinedVector<TensorValue, 4>* output) const;

  Status EvaluateOneFoldable(const NodeDef& node, std::vector<NodeDef>* outputs,
                             bool* result_too_large);

  Status FoldMergeNode(NodeDef* node, GraphDef* output_graph);
  Status FoldNode(NodeDef* node, GraphDef* output_graph,
                  bool* result_too_large);

  bool IsOnes(const NodeDef& node) const;
  bool IsZeros(const NodeDef& node) const;
  bool ReplaceOperationWithBroadcastTo(int input_to_broadcast,
                                       const GraphProperties& properties,
                                       NodeDef* node, GraphDef* graph);
  void ReplaceOperationWithIdentity(int input_to_forward,
                                    const GraphProperties& properties,
                                    NodeDef* node, GraphDef* graph);
  void ReplaceOperationWithSnapshot(int input_to_forward,
                                    const GraphProperties& properties,
                                    NodeDef* node, GraphDef* graph);
  void ReplaceOperationWithNoOp(NodeDef* node, GraphProperties* properties,
                                GraphDef* graph);
  void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
                                             const GraphProperties& properties,
                                             NodeDef* node, GraphDef* graph);
  void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
  Status ReplaceOperationWithConstant(double value,
                                      const GraphProperties& properties,
                                      const TensorShapeProto& shape,
                                      NodeDef* node, GraphDef* graph);
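  // For illustration (a sketch of typical uses, not an exhaustive list): when
  // shape information is available, Mul(x, ones) can be forwarded with
  // ReplaceOperationWithIdentity(0, ...), while Mul(x, zeros) can be replaced
  // by an all-zero constant via ReplaceOperationWithConstant(0, ...).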
  // Notice: Destroys *value.
  Status ReplaceOperationWithConstantTensor(DataType dtype, TensorProto* value,
                                            NodeDef* node, GraphDef* graph);

  void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
  Status FoldGraph(const GraphProperties& properties, GraphDef* output,
                   absl::flat_hash_set<string>* nodes_to_not_simplify);

  Status IsSimplifiableReshape(const NodeDef& node,
                               const GraphProperties& properties) const;
  Status SimplifyGraph(GraphDef* optimized_graph, GraphProperties* properties,
                       absl::flat_hash_set<string>* nodes_to_not_simplify);
  Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
                      GraphProperties* properties);

  Status RunOptimizationPass(Cluster* cluster, GrapplerItem* item,
                             GraphProperties* properties,
                             GraphDef* optimized_graph);

  // Applies partial constant folding for Concat, which is not commutative.
  // Returns true if the transformation was applied successfully.
  bool PartialConcatConstFolding(GraphDef* optimized_graph,
                                 GraphProperties* properties, NodeDef* node);

  // Applies partial constant folding for the associative operators AddN and
  // AccumulateNV2. Returns true if the transformation was applied
  // successfully.
  bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
                                  GraphProperties* properties, NodeDef* node);

  // Applies partial constant propagation through the IdentityN operator.
  // Returns true if the transformation was applied successfully.
  bool PartialConstPropThroughIdentityN(NodeDef* node);

  struct ConstantPushDownContext {
    NodeDef* op_child;
    NodeDef* const_child;
    bool left_child_is_const;
    bool right_child_is_const;
    NodeDef* left_leaf;
    NodeDef* right_leaf;
    bool left_leaf_is_const;
    bool right_leaf_is_const;

    // Shape & type information.
    const std::vector<OpInfo::TensorProperties>* parent_input_props;
    const std::vector<OpInfo::TensorProperties>* op_child_input_props;
  };

  // Populates ctx with pointers to the nodes in the expression tree for which
  // constant pushdown optimization is being considered, corresponding to one
  // of the following configurations:
  //
  //          parent                          parent
  //         /      \                        /      \
  //     op_child  const_child       const_child   op_child
  //      /    \                                    /    \
  //  left_leaf  right_leaf                 left_leaf  right_leaf
  //
  // Returns true if the expression is possibly amenable to optimization.
  // Returns false if must_have_properties is true and the input properties for
  // parent and op_child are not known.
  bool PrepareConstantPushDown(const NodeDef& parent,
                               const GraphProperties& properties,
                               bool must_have_properties,
                               ConstantPushDownContext* ctx) const;

  // Pushes down constants on '+', '-', '*', and '/' operators if applicable.
  // Returns true if the transformation was applied successfully.
  bool ConstantPushDown(GraphProperties* properties, GraphDef* optimized_graph,
                        NodeDef* node);

  // Pushes down constants on '+' and 'BiasAdd' operators if applicable.
  // Returns true if the graph was modified.
  bool ConstantPushDownBiasAdd(GraphProperties* properties,
                               GraphDef* optimized_graph, NodeDef* node);
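  // For example, constant push-down can rewrite Add(C1, Add(x, C2)) as
  // Add(x, Add(C1, C2)), so that the purely constant subtree becomes foldable
  // and a single binary op remains on the non-constant input.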
  // Aggregates constants present around a conv operator. Returns true if the
  // transformation was applied successfully.
  bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
                       const GraphProperties& properties);

  // Strength-reduces floating point division by a constant Div(x, const) to
  // multiplication by the reciprocal Mul(x, Reciprocal(const)).
  bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node);

  // Simplifies arithmetic operations with ones or zeros. Returns the status,
  // and updates the success input argument that denotes if any simplification
  // was applied.
  Status SimplifyArithmeticOperations(const GraphProperties& properties,
                                      bool use_shape_info,
                                      GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Reshape operation to an Identity operation if applicable.
  bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
                       NodeDef* node);

  // Returns true iff the node is a reduction and its reduction indices are
  // constant. Sets *indices_is_empty to true if the set of dimensions to
  // reduce along is empty (this happens often in gradient graphs).
  bool IsReductionWithConstantIndices(const NodeDef& node,
                                      bool* indices_is_empty) const;
  // Returns true if there's a possibility that a Reduce node could be
  // simplified to an Identity/Reshape.
  bool IsReductionCandidateForSimplification(
      const NodeDef& node, const GraphProperties& properties,
      TensorShapeProto* input_tensor_shape,
      TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const;
  // Returns true iff this reduction can be reduced to an identity (i.e. if the
  // input dimensions to reduce along are all of size 1 and keep_dims is true).
  bool IsReductionSimplifiableToIdentity(
      const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
      const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
  // Changes a reduction into an Identity op, returning true on success.
  bool ReplaceReductionWithIdentity(NodeDef* node) const;

  // Simplifies a Reduction operation to an Identity/Reshape operation if
  // applicable.
  bool SimplifyReduction(GraphDef* optimized_graph,
                         const GraphProperties& properties, NodeDef* node);

  // Switch(x, x) will always feed false to its false branch and true to
  // its true branch. By rewriting the graph a bit, we can propagate these
  // constants down the two output branches, and just use control dependencies
  // to trigger the selected one at runtime. For example,
  //
  //          +------+
  //      x-->|Switch|-->a  (in practice there may be multiple consumers of
  //      x-->|      |-->b   each output branch.)
  //          +------+
  //
  // is rewritten as
  //
  //          +------+
  //      x-->|Switch|-->Identity--^>Const(false)-->a
  //      x-->|      |-->Identity--^>Const(true)-->b
  //          +------+
  bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node);

  // Moves constants past an Enter node if applicable.
  bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Pack operation if applicable.
  bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Squeeze operation to an Identity operation if applicable.
  void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node);
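  // For example, if shape information shows that every input dimension is
  // known to be greater than 1, the Squeeze has nothing to remove and can be
  // rewritten as an Identity.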
  // Simplifies a Pad operation to an Identity operation if applicable.
  Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
                     GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Tile operation to an Identity operation if applicable.
  Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
                      GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a StridedSlice operation to an Identity operation if
  // applicable.
  Status SimplifyStridedSlice(const GraphProperties& properties,
                              bool use_shape_info, GraphDef* optimized_graph,
                              NodeDef* node);

  // Simplifies a Slice operation to an Identity operation if applicable.
  Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Case operation where the output_idx is known.
  bool SimplifyCase(GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Select operation where the predicates are all true or all
  // false.
  bool SimplifySelect(const GraphProperties& properties,
                      GraphDef* optimized_graph, NodeDef* node);

  // Replaces variable updates that are effectively no-ops with NoOp nodes.
  void RemoveRedundantVariableUpdates(GraphProperties* properties,
                                      GraphDef* optimized_graph, NodeDef* node);

  // Removes a Reverse op over dimensions of size 1.
  Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node);

  // Removes a RandomShuffle op if its input is a scalar or its first dimension
  // is of size 1.
  void RemoveRandomShuffle(const GraphProperties& properties,
                           bool use_shape_info, GraphDef* optimized_graph,
                           NodeDef* node);

  // Removes a Shuffle or Transpose op over dimensions of size 1.
  Status RemoveShuffleOrTranspose(const GraphProperties& properties,
                                  bool use_shape_info,
                                  GraphDef* optimized_graph, NodeDef* node);

  // Removes a Split or SplitV node if possible.
  void RemoveSplitOrSplitV(const GraphProperties& properties,
                           GraphDef* optimized_graph, NodeDef* node);

  bool GetConcatAxis(const NodeDef& node, int* axis);
  bool MergeConcat(bool use_shape_info, GraphProperties* properties,
                   GraphDef* optimized_graph, NodeDef* node);

  Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
                                               GraphDef* optimized_graph);

  RewriterConfig::Toggle opt_level_;
  // Points to an externally provided device or to owned_device_.
  DeviceBase* cpu_device_;
  std::unique_ptr<DeviceBase> owned_device_;

  std::unique_ptr<ResourceMgr> resource_mgr_;
  GraphDef* graph_;
  std::unique_ptr<NodeMap> node_map_;
  std::unordered_set<string> nodes_to_preserve_;
  // TODO(rmlarsen): Could these be keyed on absl::string_view?
  absl::flat_hash_set<string> nodes_allowlist_;
  absl::flat_hash_set<string> feed_nodes_;
  absl::flat_hash_map<string, bool> maybe_foldable_nodes_;
  bool has_fetch_;
  bool graph_modified_;
  bool graph_contains_assign_or_inplace_op_;
  bool disable_compressed_tensor_optimization_;
  bool fold_quantization_emulation_;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_