xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/constant_folding.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#define EIGEN_USE_THREADS

#include "tensorflow/core/grappler/optimizers/constant_folding.h"

#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>

#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
57 
58 namespace tensorflow {
59 namespace grappler {
// Small inlined vector of TensorValues (up to 4 elements without heap
// allocation).
using TensorVector = gtl::InlinedVector<TensorValue, 4>;

// We only fold/materialize constants smaller than 100kB.
const int64_t kMaxConstantSize = 100 * 1024;
64 
65 namespace {
66 template <typename T>
AllValuesAre(const TensorProto & proto,const T & value)67 bool AllValuesAre(const TensorProto& proto, const T& value) {
68   Tensor tensor;
69   if (!tensor.FromProto(proto)) {
70     return false;
71   }
72   auto values = tensor.flat<T>();
73   for (int i = 0; i < tensor.NumElements(); ++i) {
74     if (values(i) != value) {
75       return false;
76     }
77   }
78   return true;
79 }
80 
81 // Add new_input as a control input to node if it does not already depend on it.
82 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
83 // clean up code that should be using them.
MaybeAddControlInput(const string & ctrl_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)84 bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
85                           GraphDef* graph, NodeMap* node_map) {
86   bool already_exists = false;
87   for (const string& input : node->input()) {
88     if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
89       already_exists = true;
90       break;
91     }
92   }
93   if (!already_exists) {
94     const string ctrl_dep =
95         ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
96     node->add_input(ctrl_dep);
97     node_map->AddOutput(NodeName(ctrl_input), node->name());
98   }
99   return !already_exists;
100 }
101 
102 // Remove old_input as a control input to node.
MaybeRemoveControlInput(const string & old_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)103 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
104                              GraphDef* graph, NodeMap* node_map) {
105   bool removed_input = false;
106   bool update_node_map = true;
107   const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
108   for (int i = 0; i < node->input_size(); ++i) {
109     const string& input = node->input(i);
110     if (old_input_ctrl_dep == input) {
111       if (IsControlInput(input)) {
112         node->mutable_input()->SwapElements(i, node->input_size() - 1);
113         node->mutable_input()->RemoveLast();
114         removed_input = true;
115       } else {
116         // There is a non-control input from the same node.
117         // Don't remove the output from the NodeMap.
118         update_node_map = false;
119       }
120     }
121   }
122   if (update_node_map) {
123     node_map->RemoveOutput(NodeName(old_input), node->name());
124   }
125   return removed_input;
126 }
127 
HasTPUAttributes(const NodeDef & node)128 bool HasTPUAttributes(const NodeDef& node) {
129   AttrSlice attrs(node);
130   for (const auto& attr : attrs) {
131     if (attr.first.find("_tpu_") != attr.first.npos) {
132       return true;
133     }
134   }
135   return false;
136 }
137 
// Returns true iff `a` and `b` should be treated as distinct when deciding
// whether packed tensor values can be deduplicated. For floating point types
// the comparison is done on the raw bit patterns, so 0.0 vs -0.0 compare as
// different and bit-identical NaNs compare as equal.
template <typename T>
bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}

template <>
bool PackedValuesNotEqual(float a, float b) {
  // Compare bit patterns via memcpy: reinterpret_cast between float and
  // int32_t lvalues violates strict aliasing (undefined behavior).
  int32_t a_bits, b_bits;
  std::memcpy(&a_bits, &a, sizeof(a_bits));
  std::memcpy(&b_bits, &b, sizeof(b_bits));
  return a_bits != b_bits;
}

template <>
bool PackedValuesNotEqual(double a, double b) {
  int64_t a_bits, b_bits;
  std::memcpy(&a_bits, &a, sizeof(a_bits));
  std::memcpy(&b_bits, &b, sizeof(b_bits));
  return a_bits != b_bits;
}
152 
QuantizedTypeMinAsFloat(DataType data_type)153 float QuantizedTypeMinAsFloat(DataType data_type) {
154   switch (data_type) {
155     case DT_QINT8:
156       return Eigen::NumTraits<qint8>::lowest();
157     case DT_QUINT8:
158       return Eigen::NumTraits<quint8>::lowest();
159     case DT_QINT16:
160       return Eigen::NumTraits<qint16>::lowest();
161     case DT_QUINT16:
162       return Eigen::NumTraits<quint16>::lowest();
163     case DT_QINT32:
164       return Eigen::NumTraits<qint32>::lowest();
165     default:
166       return 0.0f;
167   }
168 }
169 
QuantizedTypeMaxAsFloat(DataType data_type)170 float QuantizedTypeMaxAsFloat(DataType data_type) {
171   switch (data_type) {
172     case DT_QINT8:
173       return Eigen::NumTraits<qint8>::highest();
174     case DT_QUINT8:
175       return Eigen::NumTraits<quint8>::highest();
176     case DT_QINT16:
177       return Eigen::NumTraits<qint16>::highest();
178     case DT_QUINT16:
179       return Eigen::NumTraits<quint16>::highest();
180     case DT_QINT32:
181       return Eigen::NumTraits<qint32>::highest();
182     default:
183       return 0.0f;
184   }
185 }
186 
187 }  // namespace
188 
ConstantFolding(RewriterConfig::Toggle opt_level,DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_emulation)189 ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
190                                  DeviceBase* cpu_device,
191                                  bool disable_compressed_tensor_optimization,
192                                  bool fold_quantization_emulation)
193     : opt_level_(opt_level),
194       cpu_device_(cpu_device),
195       disable_compressed_tensor_optimization_(
196           disable_compressed_tensor_optimization),
197       fold_quantization_emulation_(fold_quantization_emulation) {
198   resource_mgr_.reset(new ResourceMgr());
199 }
200 
// Convenience constructor: defaults the optimization level to
// RewriterConfig::ON and delegates to the main constructor.
ConstantFolding::ConstantFolding(DeviceBase* cpu_device,
                                 bool disable_compressed_tensor_optimization,
                                 bool fold_quantization_ops)
    : ConstantFolding(RewriterConfig::ON, cpu_device,
                      disable_compressed_tensor_optimization,
                      fold_quantization_ops) {}
207 
// static
// Returns the name of a control dependency anchored on `input_name`:
// - If `input_name` is already in control form ("^node"), return it as-is.
// - If the node cannot be found, return `input_name` unchanged as a fallback.
// - For non-Switch nodes, return the node's own control form.
// - For Switch nodes, anchor the dependency on an Identity node attached to
//   the relevant output port, reusing an existing one when possible and
//   otherwise creating a new one.
string ConstantFolding::AddControlDependency(const string& input_name,
                                             GraphDef* graph,
                                             NodeMap* node_map) {
  if (IsControlInput(input_name)) {
    return input_name;
  }
  const NodeDef* node = node_map->GetNode(input_name);
  // Sanity check for missing node.
  if (!node) {
    return input_name;
  }
  if (!IsSwitch(*node)) {
    return AsControlDependency(*node);
  } else {
    // We can't anchor control dependencies directly on the switch node: unlike
    // other nodes only one of the outputs of the switch node will be generated
    // when the switch node is executed, and we need to make sure the control
    // dependency is only triggered when the corresponding output is triggered.
    // We start by looking for an identity node connected to the output of the
    // switch node, and use it to anchor the control dependency.
    for (const NodeDef* output : node_map->GetOutputs(node->name())) {
      if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
        if (IsSameInput(output->name(), input_name)) {
          return AsControlDependency(*output);
        }
      }
    }
    // We haven't found an existing node where we can anchor the control
    // dependency: add a new identity node.
    int port = 0;
    string ctrl_dep_name = ParseNodeName(input_name, &port);
    // The anchor's name encodes the switch output port it guards.
    strings::StrAppend(&ctrl_dep_name, "_", port);
    ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
    const DataType output_type = node->attr().at("T").type();

    NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
    if (added_node == nullptr) {
      added_node = graph->add_node();
      added_node->set_name(ctrl_dep_name);
      added_node->set_op("Identity");
      added_node->set_device(node->device());

      (*added_node->mutable_attr())["T"].set_type(output_type);
      *added_node->add_input() = input_name;
      node_map->AddNode(added_node->name(), added_node);
      node_map->AddOutput(node->name(), added_node->name());
    }
    return AsControlDependency(*added_node);
  }
}
259 
// Forward inputs at the given indices to outputs and add a control dependency
// on node.
//
// Each consumer that reads one of the forwarded output ports of `node` is
// rewired to read the corresponding input directly; a control dependency on
// `node` is added to the consumer so the original triggering semantics are
// preserved. Returns true iff the graph was modified.
bool ConstantFolding::ForwardInputs(NodeDef* node,
                                    absl::Span<const int> inputs_to_forward) {
  // Validate all requested indices before mutating anything.
  for (int input_idx : inputs_to_forward) {
    if (input_idx < 0 || input_idx >= node->input_size()) {
      return false;
    }
  }

  // Copy the fanout set: the loop below mutates consumers and the node map.
  const auto& tmp = node_map_->GetOutputs(node->name());
  const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
  bool updated_graph = false;
  for (int input_idx : inputs_to_forward) {
    const string& input = node->input(input_idx);
    // NOTE(review): control inputs are only forwarded when there is a single
    // consumer — presumably to avoid fanning a control edge out to multiple
    // nodes; confirm against callers.
    if (IsControlInput(input) && consumers.size() > 1) {
      continue;
    }
    const NodeDef* input_node = node_map_->GetNode(NodeName(input));
    if (input_node == nullptr) {
      LOG(ERROR) << "Bad input: " << input;
      break;
    }
    // Update each consumer.
    for (NodeDef* consumer : consumers) {
      bool add_dep = false;
      for (int consumer_input_idx = 0;
           consumer_input_idx < consumer->input_size(); ++consumer_input_idx) {
        const string& consumer_input = consumer->input(consumer_input_idx);
        // Assumes control inputs trail all data inputs: nothing forwardable
        // remains once the first control input is reached.
        if (IsControlInput(consumer_input)) {
          break;
        }
        // It is illegal to add control dependencies to _Retval nodes, so we
        // can't bypass value producing `node` and forward inputs to `consumer`.
        if (IsRetval(*consumer)) {
          break;
        }
        int output_idx;
        const string input_node_name =
            ParseNodeName(consumer_input, &output_idx);
        if (input_node_name == node->name() && output_idx == input_idx) {
          consumer->set_input(consumer_input_idx, input);
          // We will keep the input from the node through a control
          // dependency, so we only need to add the consumer as an output
          // for the input node.
          node_map_->AddOutput(NodeName(input), consumer->name());
          add_dep = true;
        }
      }
      if (add_dep) {
        consumer->add_input(AsControlDependency(node->name()));
        updated_graph = true;
      }
    }
  }

  if (updated_graph) {
    // Rewiring may have introduced redundant control edges; clean them up.
    for (NodeDef* consumer : consumers) {
      DedupControlInputs(consumer);
    }
  }
  return updated_graph;
}
323 
324 // Puts the given value into the tensor at the given "flat" index.
PutValueIntoTensor(const int64_t value,const DataType & type,const int index,Tensor * tensor)325 static Status PutValueIntoTensor(const int64_t value, const DataType& type,
326                                  const int index, Tensor* tensor) {
327   if (type == DT_INT32) {
328     if (value >= INT_MAX) {
329       return Status(error::INVALID_ARGUMENT, "int32 overflow");
330     }
331     tensor->flat<int32>()(index) = static_cast<int32>(value);
332   } else {
333     tensor->flat<int64_t>()(index) = value;
334   }
335   return OkStatus();
336 }
337 
338 // Writes the given tensor shape into the given tensor.
339 // Op is assumed to be Shape, ShapeN, Size or Rank.
ConvertShapeToConstant(const string & op,const DataType & type,const PartialTensorShape & shp,Tensor * tensor)340 static Status ConvertShapeToConstant(const string& op, const DataType& type,
341                                      const PartialTensorShape& shp,
342                                      Tensor* tensor) {
343   if (op == "Shape" || op == "ShapeN") {
344     *tensor = Tensor(type, TensorShape({shp.dims()}));
345     for (int i = 0; i < shp.dims(); ++i) {
346       TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
347     }
348   } else if (op == "Size") {
349     int64_t size = 1;
350     for (int i = 0; i < shp.dims(); ++i) {
351       size *= shp.dim_size(i);
352     }
353     *tensor = Tensor(type, TensorShape({}));
354     TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
355   } else {
356     CHECK_EQ(op, "Rank");
357     *tensor = Tensor(type, TensorShape({}));
358     TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
359   }
360   return OkStatus();
361 }
362 
363 // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
OptimizedNodeExists(const NodeDef & node,StringPiece suffix) const364 bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
365                                           StringPiece suffix) const {
366   return node_map_->NodeExists(OptimizedNodeName(node, suffix));
367 }
368 
OptimizedNodeName(const NodeDef & node,StringPiece suffix) const369 string ConstantFolding::OptimizedNodeName(const NodeDef& node,
370                                           StringPiece suffix) const {
371   return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
372                              kConstantFoldingConst);
373 }
374 
IsReallyConstant(const NodeDef & node) const375 bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
376   if (!IsConstant(node)) {
377     return false;
378   }
379   // If the node is fed it's not constant anymore.
380   return feed_nodes_.find(node.name()) == feed_nodes_.end();
381 }
382 
383 // TODO(rmlarsen): Refactor to shared util.
GetTensorFromConstNode(const string & node_name_or_input,Tensor * tensor)384 bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
385                                              Tensor* tensor) {
386   const NodeDef* node = node_map_->GetNode(node_name_or_input);
387   return node != nullptr && IsReallyConstant(*node) &&
388          CheckAttrExists(*node, "value").ok() &&
389          tensor->FromProto(node->attr().at("value").tensor());
390 }
391 
// Materialize the shapes using constants whenever possible.
//
// Rewrites Shape/Size/Rank/ShapeN/TensorArraySizeV3 nodes whose results are
// statically known (according to `properties`) into Const nodes, converting
// the original data edges into control dependencies so that the constants are
// only triggered when the original ops would have run.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
  // We may add some nodes to the graph to encode control dependencies and hold
  // the materialized shapes: there is no need to process these added nodes, so
  // only iterate over the nodes of the input graph.
  const int node_count = graph_->node_size();
  for (int node_idx = 0; node_idx < node_count; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    const string op = node->op();
    if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
        op != "TensorArraySizeV3") {
      continue;
    }
    const std::vector<OpInfo::TensorProperties>& output =
        properties.GetOutputProperties(node->name());
    const std::vector<OpInfo::TensorProperties>& input =
        properties.GetInputProperties(node->name());
    if (input.empty() || output.empty()) {
      continue;
    }

    if (op == "Shape" || op == "Size" || op == "Rank") {
      CHECK_EQ(1, output.size());
      CHECK_EQ(1, input.size());

      const DataType type = output[0].dtype();
      CHECK(type == DT_INT32 || type == DT_INT64);
      const PartialTensorShape shape(input[0].shape());

      // Rank only needs a known rank; Shape/Size need every dimension known.
      if ((op != "Rank" && !shape.IsFullyDefined()) ||
          (op == "Rank" && shape.unknown_rank())) {
        continue;
      }

      Tensor constant_value(type);
      if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
        continue;
      }

      // TODO(rmlarsen): Remove this workaround for b/150861569
      // The bug involves an expression of the form Shape(ExpandDims(x)
      // with an incorrectly inferred zero-size first dimension.
      if (op == "Shape") {
        if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
      }

      // Repurpose the existing node to be the constant.
      // Device placement is preserved.
      graph_modified_ = true;
      node->set_op("Const");
      EraseRegularNodeAttributes(node);
      (*node->mutable_attr())["dtype"].set_type(type);
      constant_value.AsProtoTensorContent(
          (*node->mutable_attr())["value"].mutable_tensor());

      // Turn the data input into a control dependency: this is needed to
      // ensure that the constant value will only be run in the
      // cases where the shape/rank/size would have been run in
      // the original graph.
      string ctrl_dep =
          AddControlDependency(node->input(0), graph_, node_map_.get());
      node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
      node->set_input(0, ctrl_dep);
      // Done with the Shape/Size/Rank node, move to the next node.
      continue;
    }

    if (op == "TensorArraySizeV3") {
      // The size is only static for non-dynamic arrays whose size input is a
      // true constant.
      const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
      if (array->input_size() == 0 ||
          (array->attr().count("dynamic_size") != 0 &&
           array->attr().at("dynamic_size").b())) {
        continue;
      }
      const NodeDef* array_size =
          CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
      if (IsReallyConstant(*array_size)) {
        // Don't materialize 0 sizes to avoid triggering incorrect static
        // checks. A 0 sized array that can't grow isn't useful anyway.
        if (array_size->attr().count("value") == 0) {
          continue;
        }
        const TensorProto& raw_val = array_size->attr().at("value").tensor();
        if (raw_val.dtype() != DT_INT32) {
          continue;
        }
        Tensor value(raw_val.dtype(), raw_val.tensor_shape());
        if (!value.FromProto(raw_val)) {
          continue;
        }
        if (value.flat<int32>()(0) == 0) {
          continue;
        }

        // Turn the node into a copy of the size constant; the original data
        // inputs become control dependencies.
        graph_modified_ = true;
        node->set_op("Const");
        *node->mutable_attr() = array_size->attr();
        node->set_input(0, AsControlDependency(NodeName(node->input(0))));
        node->set_input(1, AddControlDependency(NodeName(node->input(1)),
                                                graph_, node_map_.get()));
      }
      continue;
    }

    // Handle ShapeN materialization case.
    // It's possible that not all input tensors have known shapes.
    CHECK_EQ(op, "ShapeN");
    CHECK_EQ(input.size(), output.size());
    const NodeDef* const shape_n_node = node;
    for (int port_idx = 0, idx_limit = output.size(); port_idx < idx_limit;
         ++port_idx) {
      const DataType type = output[port_idx].dtype();
      CHECK(type == DT_INT32 || type == DT_INT64);
      const PartialTensorShape shape(input[port_idx].shape());
      if (!shape.IsFullyDefined()) {
        continue;
      }
      Tensor constant_value(type);
      auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
      if (!status.ok()) {
        continue;
      }

      // We make a copy because we mutate the nodes.
      auto fanouts = node_map_->GetOutputs(shape_n_node->name());
      // Find all nodes consuming this shape and connect them through the new
      // constant node instead.
      for (NodeDef* output : fanouts) {
        // Track whether there are any direct edges left between shape_n_node
        // and this output node after the transformation.
        bool direct_edges_exist = false;
        for (int k = 0; k < output->input_size(); ++k) {
          int port;
          const string node_name = ParseNodeName(output->input(k), &port);
          if (node_name == shape_n_node->name() && port == port_idx) {
            // Create a const node as ShapeN's output if not already.
            const string const_name = OptimizedNodeName(
                *shape_n_node, strings::StrCat("-matshapes-", port_idx));
            if (node_map_->GetNode(const_name) == nullptr) {
              NodeDef* added_node = graph_->add_node();
              added_node->set_name(const_name);
              added_node->set_op("Const");
              added_node->set_device(shape_n_node->device());
              node_map_->AddNode(added_node->name(), added_node);
              (*added_node->mutable_attr())["dtype"].set_type(type);
              constant_value.AsProtoTensorContent(
                  (*added_node->mutable_attr())["value"].mutable_tensor());
              // We add a control dependency to the original ShapeN node,
              // so that the node will only be run if all inputs of the
              // original ShapeN node are run.
              string ctrl_dep = AddControlDependency(shape_n_node->name(),
                                                     graph_, node_map_.get());
              *added_node->add_input() = ctrl_dep;
              node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
            }
            *output->mutable_input(k) = const_name;
            node_map_->AddOutput(const_name, output->name());
            graph_modified_ = true;
          }
          if (node_name == shape_n_node->name() && port != port_idx) {
            direct_edges_exist = true;
          }
        }
        if (!direct_edges_exist) {
          node_map_->RemoveOutput(node->name(), output->name());
        }
      }
    }
  }

  return OkStatus();
}
564 
565 namespace {
// Extracts a shape vector from `shape_node`, which is either a "Shape" op
// (the shape of its input is read from `properties`) or a constant node
// holding an int32/int64 shape vector. Dimension sizes are appended to
// `shape` and `*min_id` is lowered to the smallest size seen. Returns false
// if the shape cannot be determined.
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
                  BCast::Vec* shape, int64_t* min_id) {
  if (shape_node.op() == "Shape") {
    const std::vector<OpInfo::TensorProperties>& prop1 =
        properties.GetInputProperties(shape_node.name());
    if (prop1.size() != 1) {
      return false;
    }
    const TensorShapeProto& shp = prop1[0].shape();
    if (shp.unknown_rank()) {
      return false;
    }
    for (const auto& dim : shp.dim()) {
      shape->push_back(dim.size());
      *min_id = std::min<int64_t>(*min_id, dim.size());
    }
  } else {
    // Constant shape vector: parse the tensor out of the "value" attr.
    if (shape_node.attr().count("value") == 0) {
      return false;
    }
    const TensorProto& raw_val = shape_node.attr().at("value").tensor();
    if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
      return false;
    }
    Tensor value(raw_val.dtype(), raw_val.tensor_shape());
    if (!value.FromProto(raw_val)) {
      return false;
    }
    for (int j = 0; j < value.NumElements(); ++j) {
      if (raw_val.dtype() == DT_INT64) {
        shape->push_back(value.vec<int64_t>()(j));
      } else {
        shape->push_back(value.vec<int>()(j));
      }
    }
  }
  return true;
}
604 }  // namespace
605 
// Materializes the outputs of a BroadcastGradientArgs node as constants when
// both input shapes are statically known (via Shape nodes with inferred
// shapes or constant shape vectors). Unknown (-1) dimensions are replaced by
// unique negative ids so two distinct unknowns never compare equal; any case
// where broadcasting behavior cannot be proven is left untouched.
Status ConstantFolding::MaterializeBroadcastGradientArgs(
    const NodeDef& node, const GraphProperties& properties) {
  const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
  const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
  if (shape_node1 == nullptr ||
      (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
      shape_node2 == nullptr ||
      (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
    return OkStatus();
  }

  // Don't optimize this again if it was already optimized and folded.
  if (OptimizedNodeExists(node, "-folded-1") ||
      OptimizedNodeExists(node, "-folded-2")) {
    return OkStatus();
  }
  int64_t min_id = 0;
  BCast::Vec shape1;
  if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
    return OkStatus();
  }
  BCast::Vec shape2;
  if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
    return OkStatus();
  }
  // A value of -1 means we don't know anything about the dimension. Replace
  // the -1 values with unique dimension ids since we don't want two '-1'
  // dimensions to be considered equal.
  for (auto& id : shape1) {
    if (id == -1) {
      id = --min_id;
    }
  }
  for (auto& id : shape2) {
    if (id == -1) {
      id = --min_id;
    }
  }

  // Beware: the reduction dimensions computed by the BCast class are valid iff
  // we assume that two distinct symbolic dimensions can't be equal and a
  // symbolic dimension can't be equal to 1. This is often but not always true,
  // so to make this optimization safe we filter out these cases.
  const int common_dims = std::min(shape1.size(), shape2.size());
  for (int i = 0; i < common_dims; ++i) {
    if (shape1[i] >= 0 && shape2[i] >= 0) {
      continue;
    }
    if (shape1[i] != shape2[i]) {
      // We're either dealing with 2 different symbolic dimensions or a
      // symbolic and a known dimension. We can't be sure whether both are
      // equal or not, so we can't be sure whether we'll be broadcasting or
      // not.
      return OkStatus();
    }
  }
  // These extra dims could be equal to 1, in which case there is no
  // broadcasting. It could also be greater than 1, in which case there would
  // be broadcasting. Since we don't know, we'll just punt.
  for (int i = common_dims, end = shape1.size(); i < end; ++i) {
    if (shape1[i] < 0) {
      return OkStatus();
    }
  }
  for (int i = common_dims, end = shape2.size(); i < end; ++i) {
    if (shape2[i] < 0) {
      return OkStatus();
    }
  }

  BCast bcast(shape1, shape2);
  if (!bcast.IsValid()) {
    return OkStatus();
  }

  // Output 0 holds the reduction indices for the first input, output 1 for
  // the second.
  BCast::Vec reduce_dims[2];
  reduce_dims[0] = bcast.grad_x_reduce_idx();
  reduce_dims[1] = bcast.grad_y_reduce_idx();

  TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
  const DataType type = node.attr().at("T").type();
  NodeDef* out[2];
  for (int j = 0; j < 2; ++j) {
    int reduction_indices = reduce_dims[j].size();
    Tensor value(type, TensorShape({reduction_indices}));
    for (int i = 0; i < reduction_indices; ++i) {
      if (type == DT_INT32) {
        value.vec<int32>()(i) = reduce_dims[j][i];
      } else {
        value.vec<int64_t>()(i) = reduce_dims[j][i];
      }
    }
    string const_name =
        OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
    out[j] = node_map_->GetNode(const_name);
    if (out[j] == nullptr) {
      out[j] = graph_->add_node();
      TF_RETURN_IF_ERROR(
          CreateNodeDef(const_name, TensorValue(&value), out[j]));
      out[j]->set_device(node.device());
      node_map_->AddNode(const_name, out[j]);
      // Anchor the new constant on the original node so it only fires when
      // the original node would have.
      string ctrl_dep =
          AddControlDependency(node.name(), graph_, node_map_.get());
      *out[j]->add_input() = ctrl_dep;
      node_map_->AddOutput(NodeName(ctrl_dep), const_name);
    }
  }

  // We make a copy here since we might mutate the set.
  const auto outputs = node_map_->GetOutputs(node.name());
  for (NodeDef* output : outputs) {
    for (int k = 0; k < output->input_size(); ++k) {
      int port;
      string node_name = ParseNodeName(output->input(k), &port);
      if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
        *output->mutable_input(k) = out[port]->name();
        node_map_->UpdateInput(output->name(), node_name, out[port]->name());
      }
    }
  }

  return OkStatus();
}
728 
// Rewrites a reduction node whose reduction-indices input is not constant:
// if the reduction provably collapses all input dimensions (a "full
// reduction"), materializes the full index set [0, input_rank) as a Const
// node and rewires input 1 to it, keeping a control dependency on the
// original indices so their producers still get scheduled.
Status ConstantFolding::MaterializeReductionIndices(
    NodeDef* node, const GraphProperties& properties) {
  if (node->input_size() < 2) {
    return OkStatus();
  }
  const NodeDef* indices = node_map_->GetNode(node->input(1));
  if (!indices || IsReallyConstant(*indices)) {
    // The reduction indices are already constant, there's nothing to do.
    return OkStatus();
  }

  const std::vector<OpInfo::TensorProperties>& input_props =
      properties.GetInputProperties(node->name());
  if (input_props.size() != 2) {
    return OkStatus();
  }
  const OpInfo::TensorProperties& input_prop = input_props[0];
  if (input_prop.shape().unknown_rank()) {
    // We can't do anything if we don't know the rank of the input.
    return OkStatus();
  }
  const int input_rank = input_prop.shape().dim_size();
  if (input_rank < 1) {
    // Unexpected graph, don't try to change it.
    return OkStatus();
  }
  const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
  DataType dtype = reduction_indices_prop.dtype();
  if (dtype != DT_INT32 && dtype != DT_INT64) {
    // Reduction indices must be int32 or int64; bail on anything else.
    return OkStatus();
  }
  PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
  const int num_reduction_indices = reduction_indices_shape.num_elements();

  const std::vector<OpInfo::TensorProperties>& output_props =
      properties.GetOutputProperties(node->name());
  if (output_props.size() != 1) {
    return OkStatus();
  }
  const OpInfo::TensorProperties& output_prop = output_props[0];
  const int output_rank =
      output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();

  // A scalar output, or one reduction index per input dimension, implies a
  // full reduction.
  bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
  if (!full_reduction) {
    // A full reduction will generate a tensor of one of the shapes
    // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
    // elements in the output of the reduction, we may deduce it from reshape
    // nodes following it.
    for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
      full_reduction = false;
      if (!IsReshape(*fanout)) {
        return OkStatus();
      }
      const std::vector<OpInfo::TensorProperties>& reshape_props =
          properties.GetOutputProperties(fanout->name());
      if (reshape_props.size() != 1) {
        return OkStatus();
      }
      const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
      PartialTensorShape shape(reshape_prop.shape());
      if (shape.num_elements() != 1) {
        return OkStatus();
      } else {
        // This fanout reshapes to a single element, consistent with a full
        // reduction; keep checking the remaining fanouts.
        full_reduction = true;
      }
    }
    if (!full_reduction) {
      return OkStatus();
    }
  }

  // We know it's a full reduction. We can generate the full set of indices to
  // reduce as a constant node.
  string const_name = OptimizedNodeName(*node, "-reduction_indices");
  if (node_map_->GetNode(const_name)) {
    // Already materialized (e.g. by a previous pass); leave the graph as is.
    return OkStatus();
  }
  NodeDef* reduction_indices = graph_->add_node();
  Tensor value(dtype, TensorShape({input_rank}));
  for (int i = 0; i < input_rank; ++i) {
    if (dtype == DT_INT32) {
      value.vec<int32>()(i) = i;
    } else {
      value.vec<int64_t>()(i) = i;
    }
  }
  TF_RETURN_IF_ERROR(
      CreateNodeDef(const_name, TensorValue(&value), reduction_indices));

  reduction_indices->set_device(node->device());
  // Anchor the new constant behind the original indices input via a control
  // dependency so upstream nodes are still executed.
  string ctrl_dep =
      AddControlDependency(node->input(1), graph_, node_map_.get());
  *reduction_indices->add_input() = ctrl_dep;
  node_map_->AddNode(const_name, reduction_indices);
  node_map_->AddOutput(NodeName(ctrl_dep), const_name);

  node->set_input(1, reduction_indices->name());
  node_map_->UpdateInput(node->name(), indices->name(),
                         reduction_indices->name());

  return OkStatus();
}
832 
MaterializeConstantValuedNode(NodeDef * node,const GraphProperties & properties)833 Status ConstantFolding::MaterializeConstantValuedNode(
834     NodeDef* node, const GraphProperties& properties) {
835   if (disable_compressed_tensor_optimization_) {
836     return OkStatus();
837   }
838   // Nodes that generate constant-valued outputs can be represented compactly in
839   // compressed format, regardless of their shape.
840   const std::vector<OpInfo::TensorProperties>& output_props =
841       properties.GetOutputProperties(node->name());
842   if (output_props.size() != 1) return OkStatus();
843   const auto& output_shape = output_props[0].shape();
844   if (!PartialTensorShape(output_shape).IsFullyDefined()) {
845     return OkStatus();
846   }
847   if (IsFill(*node)) {
848     const auto output_dtype = output_props[0].dtype();
849     NodeDef* input_node = nullptr;
850     for (int i = 0; i < 2; ++i) {
851       input_node = node_map_->GetNode(NodeName(node->input(i)));
852       if (input_node == nullptr || !IsReallyConstant(*input_node)) {
853         return OkStatus();
854       }
855     }
856     TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
857 
858     // Copy the input tensor to the fill node, set the output shape and data
859     // type, and change the node type to Const.
860     TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
861     const TensorProto& input_tensor = input_node->attr().at("value").tensor();
862     if (!input_tensor.tensor_content().empty()) {
863       // Convert the value to repeated field format, so we can use the
864       // decompression mechanism to store only a single value in the constant
865       // node, even if the shape specified in the original Fill is large.
866       Tensor t;
867       if (!t.FromProto(input_tensor)) {
868         return errors::InvalidArgument(
869             "Could not construct Tensor form TensorProto in node: ",
870             input_node->name());
871       }
872       tensor->clear_tensor_content();
873       t.AsProtoField(tensor);
874     } else {
875       *tensor = input_tensor;
876     }
877     *(tensor->mutable_tensor_shape()) = output_shape;
878     (*node->mutable_attr())["dtype"].set_type(output_dtype);
879     node->mutable_attr()->erase("T");
880     node->mutable_attr()->erase("index_type");
881     node->set_op("Const");
882     for (int i = 0; i < 2; i++) {
883       // Change inputs to a control inputs.
884       const string ctrl_dep = AsControlDependency(node->input(i));
885       node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
886       node->set_input(i, ctrl_dep);
887     }
888     graph_modified_ = true;
889   } else {
890     double value =
891         (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
892     if (value >= 0) {
893       TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
894           value, properties, output_shape, node, graph_));
895     }
896   }
897   return OkStatus();
898 }
899 
900 // Materialize output values inferred by the shape inference.
MaterializeOutputValues(NodeDef * node,const GraphProperties & properties)901 Status ConstantFolding::MaterializeOutputValues(
902     NodeDef* node, const GraphProperties& properties) {
903   const std::vector<OpInfo::TensorProperties>& output =
904       properties.GetOutputProperties(node->name());
905   if (output.size() != 1 || !output[0].has_value() ||
906       !IsFoldable(*node, &properties)) {
907     return OkStatus();
908   }
909 
910   // If this is a trivial Identity node with a constant input, just route the
911   // input around it.
912   if (IsIdentity(*node)) {
913     NodeDef* input = node_map_->GetNode(node->input(0));
914     if (IsReallyConstant(*input)) {
915       std::vector<int> inputs_to_forward;
916       std::iota(inputs_to_forward.begin(), inputs_to_forward.end(), 0);
917       graph_modified_ = ForwardInputs(node, inputs_to_forward);
918       return OkStatus();
919     }
920   }
921   // Repurpose the existing node to be the constant.
922   // Device placement is preserved.
923   TensorProto value_copy = output[0].value();
924   return ReplaceOperationWithConstantTensor(output[0].dtype(), &value_copy,
925                                             node, graph_);
926 }
927 
MaterializeConstants(const GraphProperties & properties)928 Status ConstantFolding::MaterializeConstants(
929     const GraphProperties& properties) {
930   const int node_count = graph_->node_size();
931   for (int i = 0; i < node_count; ++i) {
932     NodeDef& node = *graph_->mutable_node(i);
933     const string& op = node.op();
934     if (op == "BroadcastGradientArgs") {
935       TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
936     } else if (IsReduction(node)) {
937       TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
938     } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
939       TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
940     } else {
941       TF_RETURN_IF_ERROR(MaterializeOutputValues(&node, properties));
942     }
943   }
944   return OkStatus();
945 }
946 
IsFoldable(const NodeDef & node,const GraphProperties * properties)947 bool ConstantFolding::IsFoldable(const NodeDef& node,
948                                  const GraphProperties* properties) {
949   string key = strings::StrCat(node.name(), "/", node.op());
950   auto it = maybe_foldable_nodes_.find(key);
951   if (it == maybe_foldable_nodes_.end()) {
952     it = maybe_foldable_nodes_
953              .emplace(std::move(key), MaybeFoldable(node, properties))
954              .first;
955   }
956   if (!it->second) {
957     return false;
958   } else {
959     return IsFoldableUncached(node, properties);
960   }
961 }
962 
// Per-call foldability checks that cannot be cached: constant-input
// requirements (with the Merge special case), dtype restrictions, and a
// bound on the materialized output size.
bool ConstantFolding::IsFoldableUncached(
    const NodeDef& node, const GraphProperties* properties) const {
  // Folding not applicable to ops with no inputs.
  if (node.input().empty()) {
    return false;
  }
  // We can only fold nodes if all their inputs are known statically, except in
  // the case of a merge node that propagate the first inputs that becomes
  // available, and therefore only requires a single constant input to be
  // foldable.
  bool merge_has_constant_input = false;
  const bool is_merge = IsMerge(node);
  for (const auto& input : node.input()) {
    if (IsControlInput(input)) {
      // Control inputs don't carry values; skip them.
      continue;
    }
    const NodeDef* input_node = node_map_->GetNode(input);
    if (!input_node) {
      return false;
    }
    bool is_const = IsReallyConstant(*input_node);
    if (is_const) {
      // Don't fold strings constants for now since this causes problems with
      // checkpointing.
      if (input_node->attr().count("dtype") == 0 ||
          input_node->attr().at("dtype").type() == DT_STRING) {
        return false;
      }
      // Special case: If a Merge node has at least one constant input that
      // does not depend on a control input, we can fold it.
      merge_has_constant_input |= !HasControlInputs(*input_node);
    } else if (!is_merge) {
      return false;
    }
  }
  if (is_merge && !merge_has_constant_input) return false;
  // Fill/ZerosLike/OnesLike folding relies on the compressed tensor
  // representation, which may be explicitly disabled.
  if (disable_compressed_tensor_optimization_ &&
      (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)))
    return false;

  // If we know the output shapes, make sure that the outputs are small enough
  // to materialize.
  if (properties != nullptr && properties->HasOutputProperties(node.name())) {
    const std::vector<OpInfo::TensorProperties>& input_props =
        properties->GetInputProperties(node.name());
    const std::vector<OpInfo::TensorProperties>& output_props =
        properties->GetOutputProperties(node.name());
    // Compute total size of inputs.
    int64_t input_size_bytes = 0;
    for (const auto& input_prop : input_props) {
      const PartialTensorShape input_shape(input_prop.shape());
      if (input_shape.IsFullyDefined()) {
        input_size_bytes +=
            input_shape.num_elements() * DataTypeSize(input_prop.dtype());
      }
    }
    for (const auto& output_prop : output_props) {
      PartialTensorShape output_shape;
      // Reject malformed shape protos outright.
      if (!PartialTensorShape::BuildPartialTensorShape(output_prop.shape(),
                                                       &output_shape)
               .ok()) {
        return false;
      }
      if (output_shape.IsFullyDefined()) {
        const int64_t num_bytes =
            output_shape.num_elements() * DataTypeSize(output_prop.dtype());
        if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize) {
          // Do not fold nodes if the in-memory size of output is too large.
          // Notice that this is not exactly the same check used in
          // CreateNodeDef() where the actual encoded size is checked.
          return false;
        }
      }
    }
  }

  return true;
}
1041 
MaybeFoldable(const NodeDef & node,const GraphProperties * properties) const1042 bool ConstantFolding::MaybeFoldable(const NodeDef& node,
1043                                     const GraphProperties* properties) const {
1044   // Skip constants, they're already folded
1045   if (IsConstant(node)) {
1046     return false;
1047   }
1048   // Don't fold stateful ops such as TruncatedNormal.
1049   if (!IsFreeOfSideEffect(node)) {
1050     return false;
1051   }
1052 
1053   // Skips nodes that must be preserved except allowlisted nodes.
1054   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
1055       nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1056     return false;
1057   }
1058 
1059   // Skip control flow nodes, they can't be folded.
1060   if (ModifiesFrameInfo(node)) {
1061     return false;
1062   }
1063 
1064   // Skips ops that don't benefit from folding.
1065   if (IsPlaceholder(node)) {
1066     return false;
1067   }
1068   // `FakeParam` op is used as a placeholder in If branch function. It doesn't
1069   // have a valid output when executed.
1070   if (IsFakeParam(node)) {
1071     return false;
1072   }
1073 
1074   if (node.op() == "AccumulateNV2") {
1075     return false;
1076   }
1077   // Removing LoopCond nodes can screw up the partitioner.
1078   if (node.op() == "LoopCond") {
1079     return false;
1080   }
1081 
1082   if (!fold_quantization_emulation_ && IsQuantizationEmulation(node)) {
1083     return false;
1084   }
1085 
1086   const string& op = node.op();
1087   if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
1088       op.find("Reader") != string::npos) {
1089     return false;
1090   }
1091   if (op.find("Quantized") != string::npos || absl::StartsWith(op, "Sparse")) {
1092     return false;
1093   }
1094 
1095   // Don't fold nodes that contain TPU attributes.
1096   // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
1097   // properly forward custom attributes, b/119051778.
1098   if (HasTPUAttributes(node)) {
1099     return false;
1100   }
1101 
1102   const OpDef* op_def = nullptr;
1103   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
1104   if (!status.ok()) {
1105     return false;
1106   }
1107   // Don't fold ops without outputs.
1108   if (op_def->output_arg_size() == 0) {
1109     return false;
1110   }
1111   // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
1112   // TODO(rmlarsen): Only do this for XLA_* devices.
1113   for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
1114     if (output_arg.type() == DT_VARIANT) {
1115       return false;
1116     }
1117   }
1118 
1119   // Don't fold nodes that have no outgoing edges except allowlisted nodes.
1120   // Such nodes could be introduced by an earlier constant folding pass and are
1121   // preserved in case users want to fetch their values; re-processing them
1122   // would lead to an error of adding a duplicated node to graph.
1123   const auto& outputs = node_map_->GetOutputs(node.name());
1124   if (outputs.empty() &&
1125       nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1126     return false;
1127   }
1128   return true;
1129 }
1130 
1131 namespace {
1132 
1133 #define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME)     \
1134   case DTYPE:                                      \
1135     t->add_##NAME##_val(static_cast<TYPE>(value)); \
1136     break;
1137 
CreateConstantTensorAttrValue(DataType type,double value,const TensorShapeProto & shape,AttrValue * attr_tensor)1138 Status CreateConstantTensorAttrValue(DataType type, double value,
1139                                      const TensorShapeProto& shape,
1140                                      AttrValue* attr_tensor) {
1141   TensorProto* t = attr_tensor->mutable_tensor();
1142   t->set_dtype(type);
1143   *t->mutable_tensor_shape() = shape;
1144   switch (type) {
1145     case DT_HALF:
1146       t->add_half_val(
1147           Eigen::numext::bit_cast<uint16>(static_cast<Eigen::half>(value)));
1148       break;
1149     case DT_BFLOAT16:
1150       t->add_half_val(
1151           Eigen::numext::bit_cast<uint16>(static_cast<bfloat16>(value)));
1152       break;
1153       SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
1154       SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
1155       SET_TENSOR_VAL_CASE(DT_INT64, int64_t, int64);
1156       SET_TENSOR_VAL_CASE(DT_UINT64, int64_t, int64);
1157       SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
1158       SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
1159       SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
1160       SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
1161       SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
1162       SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
1163       SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
1164       SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
1165       SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
1166       SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
1167       SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
1168       SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
1169     default:
1170       return errors::InvalidArgument(
1171           "Unsupported type in CreateConstantTensorAttrValue: ",
1172           DataTypeString(type));
1173   }
1174   return OkStatus();
1175 }
1176 
1177 #undef SET_TENSOR_CAL_CASE
1178 
GetDataTypeFromNodeOrProps(const NodeDef & node,const GraphProperties & properties)1179 DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
1180                                     const GraphProperties& properties) {
1181   DataType dtype = DT_INVALID;
1182   if (node.attr().count("T") == 1) {
1183     dtype = node.attr().at("T").type();
1184   } else if (node.attr().count("dtype") == 1) {
1185     dtype = node.attr().at("dtype").type();
1186   } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
1187     dtype = DT_BOOL;
1188   } else {
1189     auto output_props = properties.GetOutputProperties(node.name());
1190     if (!output_props.empty()) {
1191       dtype = output_props[0].dtype();
1192     }
1193   }
1194   return dtype;
1195 }
1196 
1197 // Checks whether the shape of the const input of the Mul op is valid to perform
1198 // the MulConvPushDown optimization.
IsValidConstShapeForMulConvPushDown(const string & data_format,const TensorShapeProto & filter_shape,const TensorShapeProto & mul_const_input_shape)1199 bool IsValidConstShapeForMulConvPushDown(
1200     const string& data_format, const TensorShapeProto& filter_shape,
1201     const TensorShapeProto& mul_const_input_shape) {
1202   // If the const is a scalar, or it has fewer or same number of dimensions
1203   // than the filter and it only has single element, the optimization should
1204   // work.
1205   if (mul_const_input_shape.dim_size() <=
1206           static_cast<int>(data_format.size()) &&
1207       TensorShape(mul_const_input_shape).num_elements() == 1) {
1208     return true;
1209   }
1210 
1211   // Otherwise, check the eligibility according to data format.
1212   if (data_format == "NHWC" || data_format == "NDHWC") {
1213     TensorShapeProto new_filter_shape;
1214     if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
1215                              &new_filter_shape)) {
1216       return false;
1217     }
1218     if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
1219       return false;
1220     }
1221     // Only the last dimension could be larger than one, since broadcasting over
1222     // the last dimension (the output channel) will result in invalid filter.
1223     for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
1224       if (mul_const_input_shape.dim(i).size() > 1) return false;
1225     }
1226     return true;
1227   } else if (data_format == "NCHW" || data_format == "NCDHW") {
1228     // TODO(laigd): support NCHW and NCDHW (b/111214513).
1229     return false;
1230   }
1231   return false;
1232 }
1233 
1234 }  // namespace
1235 
// static
// Builds a Const NodeDef named `name` holding `tensor`. For tensors with more
// than 4 elements of a supported dtype, uses the packed repeated-field
// encoding and trims trailing duplicate values (decompression repeats the
// last value); otherwise serializes via tensor_content. Fails with
// InvalidArgument if the encoding is both larger than `original_size` and at
// least kMaxConstantSize bytes.
Status ConstantFolding::CreateNodeDef(const string& name,
                                      const TensorValue& tensor, NodeDef* node,
                                      size_t original_size) {
  node->set_name(name);
  node->set_op("Const");

  AttrValue attr_type;
  attr_type.set_type(tensor->dtype());
  node->mutable_attr()->insert({"dtype", attr_type});

  AttrValue attr_tensor;
  TensorProto* t = attr_tensor.mutable_tensor();
  bool optimized = false;
  size_t encoded_size;
  // Use the packed representation whenever possible to avoid generating large
  // graphdefs. Moreover, avoid repeating the last values if they're equal.
  if (tensor->NumElements() > 4) {
// Finds the last index at which the flat value changes, then copies only
// elements [0, last_index] into the proto's repeated FIELDTYPE field. The
// trailing `break` exits the enclosing switch case.
#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, FIELDTYPE)                      \
  {                                                                            \
    const auto* val_ptr = tensor->flat<TYPE>().data();                         \
    auto last = *val_ptr;                                                      \
    int64_t last_index = 0;                                                    \
    for (int64_t i = 0; i < tensor->NumElements(); ++i) {                      \
      TYPE cur = *val_ptr++;                                                   \
      if (PackedValuesNotEqual(cur, last)) {                                   \
        last = cur;                                                            \
        last_index = i;                                                        \
      }                                                                        \
    }                                                                          \
    encoded_size = (last_index + 1) * sizeof(FIELDTYPE);                       \
    if (encoded_size < kint32max) {                                            \
      optimized = true;                                                        \
      t->mutable_##FIELDTYPE##_val()->Reserve(last_index + 1);                 \
      const auto* src_ptr = tensor->flat<TYPE>().data();                       \
      auto* dst_ptr =                                                          \
          t->mutable_##FIELDTYPE##_val()->AddNAlreadyReserved(last_index + 1); \
      std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr);                   \
    }                                                                          \
  }                                                                            \
  break

    switch (tensor->dtype()) {
      case DT_FLOAT:
        POPULATE_TENSOR_PROTO(tensor, t, float, float);
      case DT_DOUBLE:
        POPULATE_TENSOR_PROTO(tensor, t, double, double);
      case DT_INT64:
        POPULATE_TENSOR_PROTO(tensor, t, int64_t, int64);
      case DT_UINT64:
        POPULATE_TENSOR_PROTO(tensor, t, uint64, uint64);
      case DT_INT32:
        POPULATE_TENSOR_PROTO(tensor, t, int32_t, int);
      case DT_UINT32:
        POPULATE_TENSOR_PROTO(tensor, t, uint32, uint32);
      case DT_INT16:
        POPULATE_TENSOR_PROTO(tensor, t, int16_t, int);
      case DT_UINT16:
        POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
      case DT_INT8:
        POPULATE_TENSOR_PROTO(tensor, t, int8_t, int);
      case DT_UINT8:
        POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
      case DT_BOOL:
        POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
      default:
        /* Do nothing. */
        break;
    }
  }
  if (optimized) {
    // Also specify type and shape.
    t->set_dtype(tensor->dtype());
    tensor->shape().AsProto(t->mutable_tensor_shape());
  } else {
    // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
    // DT_QUINT8
    tensor->AsProtoTensorContent(t);
    encoded_size = t->tensor_content().size();
  }
  node->mutable_attr()->insert({"value", attr_tensor});

  if (encoded_size > original_size && encoded_size >= kMaxConstantSize) {
    return errors::InvalidArgument(
        strings::StrCat("Can't fold ", name, ", its size would be too large (",
                        encoded_size, " >= ", kMaxConstantSize, " bytes)"));
  }
  return OkStatus();
}
1325 
// Evaluates `node` with the given constant inputs, delegating to the shared
// grappler evaluation helper bound to this folder's CPU device and resource
// manager. Output tensors are appended to `output`; the caller owns them.
Status ConstantFolding::EvaluateNode(const NodeDef& node,
                                     const TensorVector& inputs,
                                     TensorVector* output) const {
  return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
                                              resource_mgr_.get(), output);
}
1332 
// Evaluates a single foldable node: materializes its constant inputs as
// Tensors, runs the op, and converts every output into a Const NodeDef in
// `outputs`. On CreateNodeDef failure, *result_too_large is set (the typical
// cause is an oversized encoded constant) and the error is returned.
Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
                                            std::vector<NodeDef>* outputs,
                                            bool* result_too_large) {
  TensorVector inputs;
  TensorVector output_tensors;
  // Release all heap-allocated input/output tensors on every exit path.
  auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
    for (const auto& input : inputs) {
      delete input.tensor;
    }
    for (const auto& output : output_tensors) {
      if (output.tensor) {
        delete output.tensor;
      }
    }
  });

  size_t total_inputs_size = 0;
  for (const auto& input : node.input()) {
    const TensorId input_tensor = ParseTensorName(input);
    if (input_tensor.index() < 0) {
      // Control dependency; control inputs are sorted last, so stop here.
      break;
    }
    const NodeDef* input_node = node_map_->GetNode(input);
    if (!IsReallyConstant(*input_node)) {
      return Status(error::INVALID_ARGUMENT,
                    strings::StrCat("Can't fold ", node.name(), ", its ", input,
                                    " isn't constant"));
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
    const TensorProto& raw_val = input_node->attr().at("value").tensor();
    if (raw_val.dtype() == DT_INVALID) {
      return Status(
          error::INVALID_ARGUMENT,
          strings::StrCat("A tensor in the input node, with TensorId of ",
                          input_tensor.ToString(),
                          " has a dtype of DT_INVALID."));
    }
    if (IsRefType(raw_val.dtype())) {
      return errors::InvalidArgument(
          "Not allowed to construct a tensor with reference dtype, got ",
          DataTypeString(raw_val.dtype()));
    }
    Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
    if (!value->FromProto(raw_val)) {
      delete (value);
      return errors::InvalidArgument("Unable to make Tensor from proto for ",
                                     node.name(), " with shape ",
                                     raw_val.tensor_shape().DebugString());
    }
    inputs.emplace_back(value);
    // Track input bytes: CreateNodeDef() lets an output exceed the size cap
    // when it is no larger than the combined inputs.
    total_inputs_size += value->TotalBytes();
  }

  TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
  if (output_tensors.empty()) {
    return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
  }

  outputs->resize(output_tensors.size());
  for (size_t i = 0; i < output_tensors.size(); i++) {
    string node_name = OptimizedNodeName(node, "-folded");
    if (output_tensors.size() > 1) {
      // Disambiguate multi-output folds with an output-index suffix.
      node_name = strings::StrCat(node_name, "-", i);
    }
    if (output_tensors[i].tensor) {
      Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
                               total_inputs_size);
      if (!s.ok()) {
        *result_too_large = true;
        return s;
      }
    } else {
      // Create an empty NodeDef to identify dead outputs (e.g. the output of a
      // switch that's not selected by the switch predicate).
      outputs->at(i) = NodeDef();
    }
  }
  return OkStatus();
}
1413 
Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
  // Merge nodes are special, in the sense that they execute as soon as one of
  // their input is ready. We can therefore fold a merge node iff it has at
  // least one constant input without control dependency.
  // We still need to ensure that the nodes in the fanin of the merge node are
  // scheduled. We'll therefore add a control dependency from the merge node
  // to the folded constant. We end up with:
  //  * the merge node and its inputs are preserved as is
  //  * a new constant node C1, driven by the merge node through a control
  //  dependency, initialized to the value of the folded input
  //  * a new constant node C2, driven by the merge node through a control
  //  dependency, initialized to the index of the folded input
  //  * the fanout of the merge nodes is rewired to be driven by either C1 or
  //  C2.
  for (int input_index = 0; input_index < node->input_size(); ++input_index) {
    const auto& input = node->input(input_index);
    if (IsControlInput(input)) {
      // Try the next input.
      continue;
    }
    NodeDef* input_node = node_map_->GetNode(input);
    if (!IsReallyConstant(*input_node)) {
      continue;
    }
    // The constant input itself must not depend on control inputs, otherwise
    // it might not be the first input to become available.
    bool valid_input = true;
    for (const string& fanin_of_input : input_node->input()) {
      if (IsControlInput(fanin_of_input)) {
        valid_input = false;
        break;
      }
    }
    if (!valid_input) {
      // Try the next input
      continue;
    }

    string const_out_name = OptimizedNodeName(*node, "_const");
    string const_index_name = OptimizedNodeName(*node, "_index");
    if (node_map_->GetNode(const_out_name) ||
        node_map_->GetNode(const_index_name)) {
      // Intended name already exists.
      return errors::AlreadyExists(
          strings::StrCat(const_out_name, " or ", const_index_name,
                          " already present in the graph"));
    }

    // C1: a copy of the folded constant input, driven by the merge node via a
    // control dependency so the merge's fanin still gets scheduled.
    NodeDef* const_out = output_graph->add_node();
    *const_out = *input_node;
    const_out->set_name(const_out_name);
    const_out->set_device(node->device());
    *const_out->add_input() = AsControlDependency(*node);
    node_map_->AddNode(const_out->name(), const_out);
    node_map_->AddOutput(node->name(), const_out->name());

    // C2: a scalar int32 constant holding the index of the folded input
    // (the merge node's second output, value_index).
    NodeDef* const_index = output_graph->add_node();
    const_index->set_op("Const");
    Tensor index(DT_INT32, TensorShape({}));
    index.flat<int32>()(0) = input_index;
    (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
    index.AsProtoTensorContent(
        (*const_index->mutable_attr())["value"].mutable_tensor());
    const_index->set_name(const_index_name);
    const_index->set_device(node->device());
    *const_index->add_input() = AsControlDependency(*node);
    node_map_->AddNode(const_index->name(), const_index);
    node_map_->AddOutput(node->name(), const_index->name());

    // We make a copy because we mutate the nodes.
    auto outputs = node_map_->GetOutputs(node->name());
    for (NodeDef* output : outputs) {
      for (int i = 0; i < output->input_size(); i++) {
        int port;
        string node_name = ParseNodeName(output->input(i), &port);
        if (node_name == node->name()) {
          if (port == 0) {
            // Merge output 0 (the forwarded value) -> C1.
            *output->mutable_input(i) = const_out->name();
            node_map_->AddOutput(const_out->name(), output->name());
          } else if (port == 1) {
            // Merge output 1 (value_index) -> C2.
            *output->mutable_input(i) = const_index->name();
            node_map_->AddOutput(const_index->name(), output->name());
          } else {
            // This is a control dependency (or an invalid edge since the
            // merge node has only 2 outputs): preserve them.
          }
        }
      }
    }
    return OkStatus();
  }
  return OkStatus();
}
1505 
// Folds `node`, whose value(s) have been proven constant, into the graph.
// Merge nodes are delegated to FoldMergeNode. Otherwise the node's outputs
// are evaluated; a single-output node is rewritten in place into a Const,
// while a multi-output node gets one new Const per live output and its
// fanout is rewired to those constants.
Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
                                 bool* result_too_large) {
  *result_too_large = false;
  if (IsMerge(*node)) {
    // Merge needs special handling since only one of its inputs is alive
    // at runtime.
    return FoldMergeNode(node, output_graph);
  }

  // Evaluate the node; one NodeDef per output is produced.
  std::vector<NodeDef> const_nodes;
  TF_RETURN_IF_ERROR(
      EvaluateOneFoldable(*node, &const_nodes, result_too_large));
  VLOG(2) << "Folded node: " << SummarizeNodeDef(*node);

  NodeDef* constant_output = nullptr;
  for (int i = 0, end = const_nodes.size(); i < end; i++) {
    NodeDef* const_node = &const_nodes[i];
    VLOG(3) << "Generated constant node: " << SummarizeNodeDef(*const_node);
    if (const_node->name().empty()) {
      // Dead output: we can't create a constant to encode its value, so we'll
      // just skip it. We'll preserve the edges that originate from that
      // output below to preserve the overall behavior of the graph wrt dead
      // edges.
      continue;
    }

    // Returns `true` iff `const_node` already has control input named `input`.
    const auto is_duplicate_control_input = [&](const string& input) -> bool {
      auto it = absl::c_find(const_node->input(), input);
      return it != const_node->input().end();
    };

    // Forward control dependencies.
    for (const string& input : node->input()) {
      // Forward control dependencies from folded node.
      if (IsControlInput(input)) {
        if (!is_duplicate_control_input(input)) {
          *const_node->add_input() = input;
        }
      }

      // Forward control dependencies from constant inputs to folded node.
      // (A really-constant input can only have control fanins, so every
      // entry of input_node->input() is a control dependency.)
      if (!IsControlInput(input)) {
        NodeDef* input_node = node_map_->GetNode(input);
        for (const string& fanin_of_input : input_node->input()) {
          if (!is_duplicate_control_input(fanin_of_input)) {
            *const_node->add_input() = fanin_of_input;
          }
        }
      }
    }

    // We rewrite the existing node if it only has a single output, and
    // create new nodes otherwise.
    if (const_nodes.size() == 1) {
      node->set_op("Const");
      // Note we need to clear the inputs in NodeMap before we clear the inputs
      // in the node, otherwise NodeMap would see empty inputs and effectively
      // does nothing.
      node_map_->RemoveInputs(node->name());
      node->clear_input();
      *node->mutable_input() = const_node->input();
      for (const auto& input : node->input()) {
        node_map_->AddOutput(NodeName(input), node->name());
      }
      *node->mutable_attr() = const_node->attr();
      break;
    } else {
      if (node_map_->GetNode(const_node->name())) {
        // Intended name already exists.
        return errors::AlreadyExists(strings::StrCat(
            const_node->name(), " already present in the graph"));
      }
      NodeDef* added_node = output_graph->add_node();
      *added_node = *const_node;
      added_node->set_device(node->device());
      node_map_->AddNode(added_node->name(), added_node);
      for (const auto& input : added_node->input()) {
        node_map_->AddOutput(NodeName(input), added_node->name());
      }
      // All the constant nodes encoding output values have the same control
      // dependencies (since these are the control dependencies of the node
      // we're trying to fold). Record one such constant node.
      constant_output = added_node;
    }
  }

  if (const_nodes.size() > 1) {
    // Rewire the fanout of the multi-output node to the new constants.
    // We make a copy because we mutate the nodes.
    auto outputs = node_map_->GetOutputs(node->name());
    for (NodeDef* output : outputs) {
      for (int i = 0; i < output->input_size(); i++) {
        int port;
        string node_name = ParseNodeName(output->input(i), &port);
        if (node_name == node->name()) {
          if (port < 0) {
            // Propagate control dependencies if possible. If not, we'll just
            // preserve the existing control dependencies.
            // NOTE(review): the data-edge branch below passes output->name()
            // to UpdateInput, but this branch passes node_name (which equals
            // node->name() here) — confirm the first argument is intended.
            if (constant_output != nullptr) {
              node_map_->UpdateInput(node_name, NodeName(output->input(i)),
                                     constant_output->name());
              *output->mutable_input(i) = AsControlDependency(*constant_output);
            }
          } else if (port < static_cast<int>(const_nodes.size()) &&
                     !const_nodes[port].name().empty()) {
            // Replace alive outputs with the corresponding constant.
            node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
                                   const_nodes[port].name());
            *output->mutable_input(i) = const_nodes[port].name();
          } else {
            // Leave this edge alone.
            VLOG(3) << "Preserving edge from " << node->name() << ":" << port
                    << "[" << node->op() << "] to " << output->name() << ":"
                    << i << "[" << output->op() << "]";
          }
        }
      }
    }
    // If nothing consumes the folded node anymore (and it isn't preserved),
    // detach it entirely so it can be garbage-collected later.
    outputs = node_map_->GetOutputs(node->name());
    if (outputs.empty() && has_fetch_ &&
        nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
      node_map_->RemoveInputs(node->name());
      node->clear_input();
    }
  }
  return OkStatus();
}
1631 
// Repeatedly folds all foldable nodes using a worklist: whenever a node is
// folded, its fanout may have become foldable and is re-enqueued. Newly
// created constants that feed nothing are deleted, and surviving original
// nodes are moved into `optimized_graph`.
Status ConstantFolding::FoldGraph(
    const GraphProperties& properties, GraphDef* optimized_graph,
    absl::flat_hash_set<string>* nodes_to_not_simplify) {
  // We build a new optimized_graph by inserting the folded nodes into it, then
  // copy other nodes that might be needed at the end of this function.
  absl::flat_hash_set<string> processed_nodes;
  std::deque<NodeDef*> queue;
  // Seed the worklist with every currently-foldable node.
  for (int i = 0; i < graph_->node_size(); i++) {
    const NodeDef& node = graph_->node(i);
    if (IsFoldable(node, &properties) &&
        !nodes_to_not_simplify->count(node.name())) {
      queue.push_back(graph_->mutable_node(i));
    }
  }
  while (!queue.empty()) {
    NodeDef* node = queue.front();
    queue.pop_front();
    if (processed_nodes.count(node->name())) {
      continue;
    }
    // We need to record a copy of output nodes before FoldNode() modifies it.
    // We also need to ensure that the fanout is sorted deterministically.
    std::vector<NodeDef*> fanout =
        node_map_->GetOutputsOrderedByNodeName(node->name());
    bool result_too_large = false;
    Status s = FoldNode(node, optimized_graph, &result_too_large);
    processed_nodes.insert(node->name());
    if (!s.ok()) {
      VLOG(1) << "Failed to fold node " << node->DebugString()
              << "\nError message: " << s;
      if (result_too_large) {
        // Don't retry nodes whose folded result would blow up the graph.
        nodes_to_not_simplify->emplace(node->name());
      }
    } else {
      // Folding succeeded; consumers may now be foldable themselves.
      for (auto& fanout_node : fanout) {
        if (IsFoldable(*fanout_node, &properties) &&
            !nodes_to_not_simplify->count(fanout_node->name())) {
          queue.push_back(fanout_node);
        }
      }
    }
  }

  // Delete the newly created nodes that don't feed anything.
  std::vector<int> nodes_to_delete;
  for (int i = 0; i < optimized_graph->node_size(); i++) {
    const auto& fanout = node_map_->GetOutputs(optimized_graph->node(i).name());
    if (fanout.empty()) nodes_to_delete.push_back(i);
  }
  EraseNodesFromGraph(std::move(nodes_to_delete), optimized_graph);

  for (int i = 0; i < graph_->node_size(); ++i) {
    NodeDef* node = graph_->mutable_node(i);
    // If no fetch nodes are provided, we conservatively
    // move all nodes in the original graph to the output, in case users need
    // to fetch their values.
    const auto& fanout = node_map_->GetOutputs(node->name());
    if (!fanout.empty() || !has_fetch_ ||
        nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end()) {
      *(optimized_graph->add_node()) = std::move(*node);
    }
  }
  return OkStatus();
}
1696 
IsSimplifiableReshape(const NodeDef & node,const GraphProperties & properties) const1697 Status ConstantFolding::IsSimplifiableReshape(
1698     const NodeDef& node, const GraphProperties& properties) const {
1699   if (!IsReshape(node)) {
1700     return errors::Internal("Node ", node.name(), " is not a Reshape node");
1701   }
1702   if (2 > node.input_size()) {
1703     return errors::Internal("Node ", node.name(),
1704                             " must have at most 2 inputs but has ",
1705                             node.input_size());
1706   }
1707   const NodeDef* new_shape = node_map_->GetNode(node.input(1));
1708   if (!IsReallyConstant(*new_shape)) {
1709     return errors::Internal("Node ", node.name(), " has shape ",
1710                             new_shape->DebugString(),
1711                             " which is not a constant");
1712   }
1713   TensorVector outputs;
1714   auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1715     for (const auto& output : outputs) {
1716       delete output.tensor;
1717     }
1718   });
1719 
1720   Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
1721   if (!s.ok()) {
1722     return errors::Internal("Could not evaluate node ", node.name());
1723   }
1724   if (outputs.size() != 1) {
1725     return errors::Internal("Node ", node.name(),
1726                             " must have exactly 1 output but has ",
1727                             outputs.size());
1728   }
1729 
1730   const std::vector<OpInfo::TensorProperties>& props =
1731       properties.GetInputProperties(node.name());
1732   if (props.empty()) {
1733     return errors::Internal("Node ", node.name(), " has no properties");
1734   }
1735   const OpInfo::TensorProperties& prop = props[0];
1736   if (prop.dtype() == DT_INVALID) {
1737     return errors::Internal("Node ", node.name(), " has property ",
1738                             prop.DebugString(), " with invalid dtype");
1739   }
1740   const PartialTensorShape shape(prop.shape());
1741   if (!shape.IsFullyDefined()) {
1742     return errors::Internal("Node ", node.name(), " has property ",
1743                             prop.DebugString(), " with shape ",
1744                             shape.DebugString(), " which is not fully defined");
1745   }
1746 
1747   PartialTensorShape new_dims;
1748   if (outputs[0]->dtype() == DT_INT32) {
1749     std::vector<int32> shp;
1750     for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1751       int32_t dim = outputs[0]->flat<int32>()(i);
1752       shp.push_back(dim);
1753     }
1754     s = TensorShapeUtils::MakeShape(shp, &new_dims);
1755     if (!s.ok()) return s;
1756   } else {
1757     std::vector<int64_t> shp;
1758     for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1759       int64_t dim = outputs[0]->flat<int64_t>()(i);
1760       shp.push_back(dim);
1761     }
1762     s = TensorShapeUtils::MakeShape(shp, &new_dims);
1763     if (!s.ok()) return s;
1764   }
1765 
1766   if (!shape.IsCompatibleWith(new_dims)) {
1767     return errors::Internal("Expected shape ", shape.DebugString(),
1768                             "to be compatible with ", new_dims.DebugString());
1769   }
1770 
1771   return OkStatus();
1772 }
1773 
// Expands to a switch case that returns true iff every element of the
// node's "value" tensor equals VALUE, interpreted in DTYPE's C++ type.
#define IS_VALUE_CASE(DTYPE, VALUE)                   \
  case DTYPE:                                         \
    return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
        node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE))

// Specializations used by IsOnes() / IsZeros() below.
#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
1781 
IsOnes(const NodeDef & node) const1782 bool ConstantFolding::IsOnes(const NodeDef& node) const {
1783   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1784     return false;
1785   }
1786   if (IsOnesLike(node)) return true;
1787   if (IsZerosLike(node)) return false;
1788   if (node.op() == "Fill") {
1789     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1790     return values != nullptr && IsOnes(*values);
1791   }
1792   if (node.op() != "Const") return false;
1793   if (node.attr().count("dtype") == 0) return false;
1794   const auto dtype = node.attr().at("dtype").type();
1795   switch (dtype) {
1796     IS_ONES_CASE(DT_BOOL);
1797     IS_ONES_CASE(DT_HALF);
1798     IS_ONES_CASE(DT_BFLOAT16);
1799     IS_ONES_CASE(DT_FLOAT);
1800     IS_ONES_CASE(DT_DOUBLE);
1801     IS_ONES_CASE(DT_COMPLEX64);
1802     IS_ONES_CASE(DT_COMPLEX128);
1803     IS_ONES_CASE(DT_UINT8);
1804     IS_ONES_CASE(DT_INT8);
1805     IS_ONES_CASE(DT_UINT16);
1806     IS_ONES_CASE(DT_INT16);
1807     IS_ONES_CASE(DT_INT32);
1808     IS_ONES_CASE(DT_INT64);
1809     IS_ONES_CASE(DT_QINT32);
1810     IS_ONES_CASE(DT_QINT16);
1811     IS_ONES_CASE(DT_QUINT16);
1812     IS_ONES_CASE(DT_QINT8);
1813     IS_ONES_CASE(DT_QUINT8);
1814     default:
1815       VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1816       return false;
1817   }
1818   return false;
1819 }
1820 
// Returns true iff `node` is known to evaluate to an all-zeros tensor:
// ZerosLike, Fill driven by an all-zeros value, or a Const whose "value"
// attribute is all zeros. Feed nodes are never considered constant since
// their value can be overridden at runtime.
bool ConstantFolding::IsZeros(const NodeDef& node) const {
  if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
    return false;
  }
  if (IsOnesLike(node)) return false;
  if (IsZerosLike(node)) return true;
  if (node.op() == "Fill") {
    // Fill's input 1 is the fill value; recurse on it.
    NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
    return values != nullptr && IsZeros(*values);
  }
  if (!IsConstant(node)) return false;
  if (node.attr().count("dtype") == 0) return false;
  const auto dtype = node.attr().at("dtype").type();
  switch (dtype) {
    IS_ZEROS_CASE(DT_BOOL);
    IS_ZEROS_CASE(DT_HALF);
    IS_ZEROS_CASE(DT_BFLOAT16);
    IS_ZEROS_CASE(DT_FLOAT);
    IS_ZEROS_CASE(DT_DOUBLE);
    IS_ZEROS_CASE(DT_COMPLEX64);
    IS_ZEROS_CASE(DT_COMPLEX128);
    IS_ZEROS_CASE(DT_UINT8);
    IS_ZEROS_CASE(DT_INT8);
    IS_ZEROS_CASE(DT_UINT16);
    IS_ZEROS_CASE(DT_INT16);
    IS_ZEROS_CASE(DT_INT32);
    IS_ZEROS_CASE(DT_INT64);
    IS_ZEROS_CASE(DT_QINT32);
    IS_ZEROS_CASE(DT_QINT16);
    IS_ZEROS_CASE(DT_QUINT16);
    IS_ZEROS_CASE(DT_QINT8);
    IS_ZEROS_CASE(DT_QUINT8);
    default:
      VLOG(1) << "Unsupported type " << DataTypeString(dtype);
      return false;
  }
  return false;
}
1859 
ReplaceOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1860 bool ConstantFolding::ReplaceOperationWithBroadcastTo(
1861     int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1862     GraphDef* graph) {
1863   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1864   if (dtype == DT_INVALID) {
1865     return false;
1866   }
1867   const PartialTensorShape shape(
1868       properties.GetOutputProperties(node->name())[0].shape());
1869   if (!shape.IsFullyDefined()) {
1870     return false;
1871   }
1872   // Create constant node with shape.
1873   const string const_name = OptimizedNodeName(
1874       *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
1875   if (node_map_->GetNode(const_name) != nullptr) {
1876     return false;
1877   }
1878 
1879   Tensor shape_t;
1880   if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) {
1881     return false;
1882   }
1883   NodeDef tmp;
1884   if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) {
1885     return false;
1886   }
1887   NodeDef* const_node = graph->add_node();
1888   const_node->Swap(&tmp);
1889   const_node->set_device(node->device());
1890   node_map_->AddNode(const_name, const_node);
1891   for (int i = 0; i < node->input_size(); ++i) {
1892     if (i != input_to_broadcast) {
1893       // Add a control input on the unused input.
1894       string ctrl_dep = AddControlDependency(NodeName(node->input(i)), graph,
1895                                              node_map_.get());
1896       *const_node->add_input() = ctrl_dep;
1897       node_map_->AddOutput(NodeName(ctrl_dep), const_name);
1898     }
1899   }
1900 
1901   // Rewrite `node` in-place to BroadcastTo.
1902   node->set_op("BroadcastTo");
1903   EraseRegularNodeAttributes(node);
1904   (*node->mutable_attr())["T"].set_type(dtype);
1905   (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
1906   // Set the designated input to BroadcastTo.
1907   node->mutable_input()->SwapElements(0, input_to_broadcast);
1908   // Keep all other inputs as control dependencies.
1909   for (int i = 1; i < node->input_size(); ++i) {
1910     if (IsControlInput(node->input(i))) {
1911       break;
1912     }
1913     const string ctrl_dep =
1914         AddControlDependency(node->input(i), graph, node_map_.get());
1915     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1916     node->set_input(i, ctrl_dep);
1917   }
1918   // Add the shape argument.
1919   *node->add_input() = const_node->name();
1920   node_map_->AddOutput(const_name, node->name());
1921   node->mutable_input()->SwapElements(1, node->input_size() - 1);
1922   return true;
1923 }
1924 
1925 // Replace an operation with Identity.
ReplaceOperationWithIdentity(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1926 void ConstantFolding::ReplaceOperationWithIdentity(
1927     int input_to_forward, const GraphProperties& properties, NodeDef* node,
1928     GraphDef* graph) {
1929   if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
1930   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1931   if (dtype == DT_INVALID) return;
1932 
1933   node->set_op("Identity");
1934   EraseRegularNodeAttributes(node);
1935   (*node->mutable_attr())["T"].set_type(dtype);
1936   // Propagate the designated input through the identity.
1937   node->mutable_input()->SwapElements(0, input_to_forward);
1938   // Add all other inputs as control dependencies.
1939   for (int i = 1; i < node->input_size(); ++i) {
1940     if (IsControlInput(node->input(i))) {
1941       break;
1942     }
1943     const string ctrl_dep =
1944         AddControlDependency(node->input(i), graph, node_map_.get());
1945     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1946     node->set_input(i, ctrl_dep);
1947   }
1948   graph_modified_ = true;
1949 }
1950 
ReplaceOperationWithSnapshot(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1951 void ConstantFolding::ReplaceOperationWithSnapshot(
1952     int input_to_forward, const GraphProperties& properties, NodeDef* node,
1953     GraphDef* graph) {
1954   // If the graph contains no ops that mutate their inputs, we can
1955   // use Identity instead of Snapshot.
1956   if (!graph_contains_assign_or_inplace_op_) {
1957     ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
1958     return;
1959   }
1960 
1961   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1962   if (dtype == DT_INVALID) return;
1963 
1964   node->set_op("Snapshot");
1965   EraseRegularNodeAttributes(node);
1966   (*node->mutable_attr())["T"].set_type(dtype);
1967   // Propagate the designated input through the Snapshot.
1968   node->mutable_input()->SwapElements(0, input_to_forward);
1969   // Add all other inputs as control dependencies.
1970   for (int i = 1; i < node->input_size(); ++i) {
1971     if (IsControlInput(node->input(i))) {
1972       break;
1973     }
1974     const string ctrl_dep =
1975         AddControlDependency(node->input(i), graph, node_map_.get());
1976     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1977     node->set_input(i, ctrl_dep);
1978   }
1979   graph_modified_ = true;
1980 }
1981 
1982 // Replace a node with NoOp. Change all inputs to control dependencies.
1983 // If the node has non-control outputs, no change will be performed.
ReplaceOperationWithNoOp(NodeDef * node,GraphProperties * properties,GraphDef * graph)1984 void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
1985                                                GraphProperties* properties,
1986                                                GraphDef* graph) {
1987   if (HasRegularOutputs(*node, *node_map_)) return;
1988   node->set_op("NoOp");
1989   EraseRegularNodeAttributes(node);
1990   EraseNodeOutputAttributes(node);
1991   // Erase attributes that describe output properties.
1992   properties->ClearOutputProperties(node->name());
1993   // Change all inputs to control dependencies.
1994   for (int i = 0; i < node->input_size(); ++i) {
1995     if (IsControlInput(node->input(i))) {
1996       break;
1997     }
1998     const string ctrl_dep =
1999         AddControlDependency(node->input(i), graph, node_map_.get());
2000     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
2001     node->set_input(i, ctrl_dep);
2002   }
2003   DedupControlInputs(node);
2004   graph_modified_ = true;
2005 }
2006 
ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)2007 void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
2008     int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
2009     GraphDef* graph) {
2010   if (!ReplaceOperationWithBroadcastTo(input_to_broadcast, properties, node,
2011                                        graph)) {
2012     return;
2013   }
2014   graph_modified_ = true;
2015 }
2016 
ReplaceDivisionOfOnesByReciprocal(NodeDef * node,GraphDef * graph)2017 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
2018                                                         GraphDef* graph) {
2019   node->set_op("Reciprocal");
2020   node->mutable_input()->SwapElements(0, 1);
2021   const string ctrl_dep =
2022       AddControlDependency(node->input(1), graph, node_map_.get());
2023   node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
2024   node->set_input(1, ctrl_dep);
2025   graph_modified_ = true;
2026 }
2027 
ReplaceSubtractionFromZeroByNegation(NodeDef * node,GraphDef * graph)2028 void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
2029                                                            GraphDef* graph) {
2030   node->set_op("Neg");
2031   node->mutable_input()->SwapElements(0, 1);
2032   const string ctrl_dep =
2033       AddControlDependency(node->input(1), graph, node_map_.get());
2034   node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
2035   node->set_input(1, ctrl_dep);
2036   graph_modified_ = true;
2037 }
2038 
ReplaceOperationWithConstantTensor(DataType dtype,TensorProto * value,NodeDef * node,GraphDef * graph)2039 Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
2040                                                            TensorProto* value,
2041                                                            NodeDef* node,
2042                                                            GraphDef* graph) {
2043   if (dtype == DT_VARIANT) return OkStatus();
2044   node->set_op("Const");
2045   EraseRegularNodeAttributes(node);
2046   (*node->mutable_attr())["dtype"].set_type(dtype);
2047   (*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
2048   // Convert all inputs to control dependencies.
2049   for (int i = 0; i < node->input_size(); ++i) {
2050     if (IsControlInput(node->input(i))) {
2051       break;
2052     }
2053     const string ctrl_dep =
2054         AddControlDependency(node->input(i), graph, node_map_.get());
2055     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
2056     node->set_input(i, ctrl_dep);
2057   }
2058   DedupControlInputs(node);
2059   graph_modified_ = true;
2060   return OkStatus();
2061 }
2062 
ReplaceOperationWithConstant(double value,const GraphProperties & properties,const TensorShapeProto & shape,NodeDef * node,GraphDef * graph)2063 Status ConstantFolding::ReplaceOperationWithConstant(
2064     double value, const GraphProperties& properties,
2065     const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
2066   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
2067   if (dtype == DT_VARIANT) return OkStatus();
2068   AttrValue tensor_attr;
2069   Status s = CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr);
2070   if (!s.ok()) {
2071     // Fail gracefully without mutating the graph.
2072     VLOG(1) << "Failed to replace node " << node->name() << " of type "
2073             << DataTypeString(dtype) << " with constant tensor of value "
2074             << value;
2075     return OkStatus();
2076   }
2077   return ReplaceOperationWithConstantTensor(dtype, tensor_attr.mutable_tensor(),
2078                                             node, graph);
2079 }
2080 
SimplifyGraph(GraphDef * optimized_graph,GraphProperties * properties,absl::flat_hash_set<string> * nodes_to_not_simplify)2081 Status ConstantFolding::SimplifyGraph(
2082     GraphDef* optimized_graph, GraphProperties* properties,
2083     absl::flat_hash_set<string>* nodes_to_not_simplify) {
2084   for (int i = 0; i < optimized_graph->node_size(); ++i) {
2085     NodeDef* node = optimized_graph->mutable_node(i);
2086     // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
2087     // generalize to only restrict certain simplifications.
2088     if (nodes_to_not_simplify->find(node->name()) ==
2089         nodes_to_not_simplify->end()) {
2090       if (HasTPUAttributes(*node)) {
2091         nodes_to_not_simplify->insert(node->name());
2092         continue;
2093       }
2094 
2095       TF_RETURN_IF_ERROR(SimplifyNode(node, optimized_graph, properties));
2096     }
2097   }
2098   return OkStatus();
2099 }
2100 
// Helper macros for SimplifyNode: each runs one rewrite and returns from the
// enclosing function as soon as the graph has been modified (or an error
// occurred), so at most one simplification is applied per call.
//
// For Status-returning rewrites that set graph_modified_ themselves.
#define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
  TF_RETURN_IF_ERROR(EXPR);               \
  if (graph_modified_) return OkStatus()

// For bool-returning rewrites: the result is recorded in graph_modified_.
#define SET_AND_RETURN_IF_MODIFIED(EXPR) \
  graph_modified_ = EXPR;                \
  if (graph_modified_) return OkStatus()

// For void rewrites that set graph_modified_ themselves.
#define RETURN_IF_MODIFIED(EXPR) \
  EXPR;                          \
  if (graph_modified_) return OkStatus()
2112 
// Applies the first applicable rewrite from a fixed, ordered list of
// simplifications to `node`. The RETURN_IF_* macros above make the function
// return as soon as one rewrite fires, so the ordering of the calls below is
// significant. graph_modified_ is temporarily reset so each rewrite can
// observe whether it (alone) changed the graph, and restored on the
// fall-through path where nothing fired.
Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
                                     GraphProperties* properties) {
  bool graph_modified_cached = graph_modified_;
  graph_modified_ = false;

  bool use_shape_info = properties->has_properties();
  RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
      *properties, use_shape_info, optimized_graph, node));
  RETURN_IF_MODIFIED(
      RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      RemoveReverse(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifySlice(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyTile(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyPad(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_MODIFIED(
      SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifyReduction(optimized_graph, *properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifyReshape(*properties, use_shape_info, node));
  RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
      *properties, use_shape_info, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      ConstantPushDown(properties, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      MulConvPushDown(optimized_graph, node, *properties));
  SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
  SET_AND_RETURN_IF_MODIFIED(
      PartialAssocOpConstFolding(optimized_graph, properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      MergeConcat(use_shape_info, properties, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      PartialConcatConstFolding(optimized_graph, properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      ConstantPushDownBiasAdd(properties, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifySelect(*properties, optimized_graph, node));
  RETURN_IF_MODIFIED(
      RemoveRedundantVariableUpdates(properties, optimized_graph, node));

  // Nothing fired: restore the caller's graph_modified_ flag.
  graph_modified_ = graph_modified_cached;
  return OkStatus();
}
2168 
RemoveSplitOrSplitV(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2169 void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
2170                                           GraphDef* optimized_graph,
2171                                           NodeDef* node) {
2172   if (node->attr().count("num_split") == 0) return;
2173   if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
2174     ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
2175   }
2176   if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
2177     ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2178   }
2179 }
2180 
RemoveShuffleOrTranspose(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2181 Status ConstantFolding::RemoveShuffleOrTranspose(
2182     const GraphProperties& properties, bool use_shape_info,
2183     GraphDef* optimized_graph, NodeDef* node) {
2184   if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
2185     return OkStatus();
2186   Tensor permutation_tensor;
2187   if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
2188       properties.HasInputProperties(node->name())) {
2189     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2190     std::vector<int> permutation;
2191     for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
2192       if (permutation_tensor.dtype() == DT_INT64) {
2193         permutation.push_back(permutation_tensor.vec<int64_t>()(j));
2194       } else {
2195         permutation.push_back(permutation_tensor.vec<int>()(j));
2196       }
2197     }
2198     int permutation_size = permutation.size();
2199     if (permutation_size != shape.dim_size()) {
2200       // Number of elements in perm should be same as dim_size. Skip if not.
2201       return OkStatus();
2202     }
2203     // The node is replaceable iff
2204     // dim_size == 0 || all dims have size 1 ||
2205     // all dims with > 1 size are not permuted.
2206     bool replaceable = true;
2207     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2208       replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
2209     }
2210     if (replaceable) {
2211       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2212     }
2213   }
2214   return OkStatus();
2215 }
2216 
RemoveRandomShuffle(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2217 void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
2218                                           bool use_shape_info,
2219                                           GraphDef* optimized_graph,
2220                                           NodeDef* node) {
2221   if (use_shape_info && IsRandomShuffle(*node) &&
2222       !properties.GetInputProperties(node->name()).empty()) {
2223     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2224     // The node is replaceable iff
2225     // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
2226     if (!shape.unknown_rank() &&
2227         (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
2228       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2229     }
2230   }
2231 }
2232 
RemoveReverse(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2233 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
2234                                       bool use_shape_info,
2235                                       GraphDef* optimized_graph,
2236                                       NodeDef* node) {
2237   if (!use_shape_info || node->op() != "ReverseV2") return OkStatus();
2238   Tensor axis;
2239   if (properties.HasInputProperties(node->name()) &&
2240       GetTensorFromConstNode(node->input(1), &axis)) {
2241     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2242     if (shape.unknown_rank()) return OkStatus();
2243     std::set<int> target_axes;
2244     for (int j = 0; j < axis.NumElements(); ++j) {
2245       // value of axis can be negative.
2246       if (axis.dtype() == DT_INT64) {
2247         target_axes.insert((axis.vec<int64_t>()(j) + shape.dim_size()) %
2248                            shape.dim_size());
2249       } else {
2250         target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
2251                            shape.dim_size());
2252       }
2253     }
2254 
2255     // The node is replaceable iff
2256     // unknown_rank == false &&
2257     // (dim_size == 0 || all dims have size 1 ||
2258     //  all dims with > 1 size are not in target_axes)
2259     bool replaceable = true;
2260     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2261       replaceable &=
2262           shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
2263     }
2264     if (replaceable) {
2265       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2266     }
2267   }
2268   return OkStatus();
2269 }
2270 
SimplifySlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2271 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
2272                                       bool use_shape_info,
2273                                       GraphDef* optimized_graph,
2274                                       NodeDef* node) {
2275   if (!use_shape_info || !IsSlice(*node)) return OkStatus();
2276   Tensor begin;
2277   Tensor size;
2278   if (properties.HasInputProperties(node->name()) &&
2279       GetTensorFromConstNode(node->input(1), &begin) &&
2280       GetTensorFromConstNode(node->input(2), &size)) {
2281     const auto& input = properties.GetInputProperties(node->name())[0];
2282     // The node is replaceable iff unknown_rank == false &&
2283     // begin == 0 && (size == -1 || size == input_shape) for all dimensions
2284     bool replaceable = !input.shape().unknown_rank();
2285     for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2286       if (begin.dtype() == DT_INT32) {
2287         replaceable &= begin.vec<int>()(j) == 0;
2288       } else {
2289         replaceable &= begin.vec<int64_t>()(j) == 0;
2290       }
2291       if (size.dtype() == DT_INT32) {
2292         replaceable &= (size.vec<int>()(j) == -1 ||
2293                         size.vec<int>()(j) == input.shape().dim(j).size());
2294       } else {
2295         replaceable &= (size.vec<int64_t>()(j) == -1 ||
2296                         size.vec<int64_t>()(j) == input.shape().dim(j).size());
2297       }
2298     }
2299     if (replaceable) {
2300       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2301     }
2302   }
2303   return OkStatus();
2304 }
2305 
// Replaces a StridedSlice with Identity when it provably selects the whole
// input: for every dimension, begin is 0 (or begin_mask is set), end equals
// the dimension size (or end_mask is set), and the stride is 1. Dimensions
// covered by an ellipsis are full slices by construction and are skipped.
// Requires a fully-known input shape and constant begin/end/strides inputs.
Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
                                             bool use_shape_info,
                                             GraphDef* optimized_graph,
                                             NodeDef* node) {
  if (use_shape_info && IsStridedSlice(*node) &&
      properties.GetInputProperties(node->name()).size() == 4) {
    TF_RETURN_IF_ERROR(
        CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
    if (node->attr().at("new_axis_mask").i() != 0 ||
        node->attr().at("shrink_axis_mask").i() != 0) {
      // Skip nodes with new/shrink axis mask, since they involve dimension
      // changes.
      return OkStatus();
    }
    const auto& input = properties.GetInputProperties(node->name())[0];
    for (int j = 0; j < input.shape().dim_size(); ++j) {
      // Skip if input shape is not fully determined.
      if (input.shape().dim(j).size() < 0) {
        return OkStatus();
      }
    }

    // Inputs 1..3 are begin, end, strides; all must be constants.
    std::vector<Tensor> input_tensors(3);
    for (int i = 1; i < 4; ++i) {
      if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
        return OkStatus();
      }
    }

    const Tensor& begin = input_tensors[0];
    const Tensor& end = input_tensors[1];
    const Tensor& strides = input_tensors[2];

    TF_RETURN_IF_ERROR(
        CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
    int begin_mask = node->attr().at("begin_mask").i();
    int end_mask = node->attr().at("end_mask").i();
    // Input dimensions that an (explicit or implied) ellipsis expands to.
    std::set<int> expanded_ellipsis_indices;
    int ellipsis_index = -1;
    for (int j = 0; j < input.shape().dim_size(); ++j) {
      // find the ellipsis_mask. If not found, insert one in the end if
      // necessary.
      if (node->attr().at("ellipsis_mask").i() & 1 << j ||
          (ellipsis_index == -1 && j >= strides.NumElements())) {
        ellipsis_index = j;
      }
      // insert the indices that are immediately after ellipsis_index if
      // necessary.
      if (ellipsis_index != -1 &&
          input.shape().dim_size() >
              strides.NumElements() + j - ellipsis_index) {
        expanded_ellipsis_indices.insert(j);
      }
    }

    // The node is replaceable iff unknown_rank == false &&
    // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
    //  && strides == 1) for all dimensions.
    bool replaceable = !input.shape().unknown_rank();
    for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
      if (expanded_ellipsis_indices.find(j) !=
          expanded_ellipsis_indices.end()) {
        // ellipsis_mask is effective on current dimension.
        continue;
      }
      // when we have ellipsis_mask in between, input.shape().dim_size() will
      // be greater than strides.NumElements(), since we will insert
      // as many as expanded_ellipsis_indices.size() axes during computation.
      // We need to subtract this number from j.
      int i = j;
      int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
      if (ellipsis_index != -1 &&
          j >= ellipsis_index + expanded_ellipsis_indices_size) {
        i = j - expanded_ellipsis_indices_size;
      }
      // begin/end/strides may be int32 or int64. Index i addresses those
      // tensors (j adjusted for dims the ellipsis inserted).
      int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
                                        : begin.vec<int64_t>()(i);
      int e =
          end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64_t>()(i);
      int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
                                          : strides.vec<int64_t>()(i);
      replaceable &= (begin_mask & 1 << i || b == 0) &&
                     (end_mask & 1 << i || e == input.shape().dim(j).size()) &&
                     s == 1;
    }
    if (replaceable) {
      ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
    }
  }
  return OkStatus();
}
2397 
SimplifyTile(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2398 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
2399                                      bool use_shape_info,
2400                                      GraphDef* optimized_graph, NodeDef* node) {
2401   Tensor multiplies;
2402   if (use_shape_info && IsTile(*node) &&
2403       GetTensorFromConstNode(node->input(1), &multiplies)) {
2404     // The node is replaceable iff all values in multiplies are 1.
2405     bool replaceable = true;
2406     if (multiplies.dtype() == DT_INT32) {
2407       for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
2408         replaceable &= multiplies.vec<int>()(j) == 1;
2409       }
2410     } else {
2411       for (int j = 0; replaceable && j < multiplies.vec<int64_t>().size();
2412            ++j) {
2413         replaceable &= multiplies.vec<int64_t>()(j) == 1;
2414       }
2415     }
2416     if (replaceable) {
2417       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2418     }
2419   }
2420   return OkStatus();
2421 }
2422 
SimplifyPad(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2423 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
2424                                     bool use_shape_info,
2425                                     GraphDef* optimized_graph, NodeDef* node) {
2426   if (!use_shape_info || !IsPad(*node)) return OkStatus();
2427 
2428   Tensor paddings;
2429   if (GetTensorFromConstNode(node->input(1), &paddings)) {
2430     // The node is replaceable iff all values in paddings are 0.
2431     bool replaceable = true;
2432     if (paddings.dtype() == DT_INT32) {
2433       const auto flatten = paddings.flat<int32>();
2434       for (int j = 0; replaceable && j < flatten.size(); ++j) {
2435         replaceable &= flatten(j) == 0;
2436       }
2437     } else {
2438       const auto flatten = paddings.flat<int64_t>();
2439       for (int j = 0; replaceable && j < flatten.size(); ++j) {
2440         replaceable &= flatten(j) == 0;
2441       }
2442     }
2443     if (replaceable) {
2444       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2445     }
2446   }
2447   return OkStatus();
2448 }
2449 
SimplifySqueeze(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2450 void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
2451                                       bool use_shape_info,
2452                                       GraphDef* optimized_graph,
2453                                       NodeDef* node) {
2454   if (use_shape_info && IsSqueeze(*node) &&
2455       !properties.GetInputProperties(node->name()).empty()) {
2456     // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
2457     // error to squeeze a dimension that is not 1, so we only need to check
2458     // whether the input has > 1 size for each dimension.
2459     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2460     // The node is replaceable iff
2461     // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
2462     bool replaceable = !shape.unknown_rank();
2463     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2464       replaceable &= shape.dim(j).size() > 1;
2465     }
2466     if (replaceable) {
2467       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2468     }
2469   }
2470 }
2471 
// Rewrites a Pack (tf.stack) node with a single non-control input into an
// equivalent ExpandDims of that input along the pack axis, backed by a new
// constant axis node. Returns true if the node was modified.
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
  const string axis_node_name = OptimizedNodeName(*node, "_const_axis");
  if (!IsPack(*node) || NumNonControlInputs(*node) != 1 ||
      node_map_->NodeExists(axis_node_name)) {
    return false;
  }

  // It's unsafe to add a control dependency on the feed node, because it might
  // have been never executed otherwise.
  if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
    return false;
  }

  // Create constant axis node. Pack defaults to axis 0 when the attribute is
  // absent.
  Tensor axis_t(DT_INT32, TensorShape({}));
  const int axis =
      node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
  NodeDef new_node;
  if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
      !CreateNodeDef(axis_node_name, TensorValue(&axis_t), &new_node).ok()) {
    return false;
  }
  NodeDef* axis_node = optimized_graph->add_node();
  *axis_node = std::move(new_node);
  axis_node->set_name(axis_node_name);
  node_map_->AddNode(axis_node->name(), axis_node);
  // Add a control dependency to make sure axis_node is in the right frame.
  const string ctrl_dep = ConstantFolding::AddControlDependency(
      node->input(0), optimized_graph, node_map_.get());
  axis_node->add_input(ctrl_dep);
  axis_node->set_device(node->device());
  node_map_->AddOutput(NodeName(node->input(0)), axis_node->name());
  // Morph the Pack into ExpandDims: drop Pack-only attrs ("axis", "N"), add
  // the "Tdim" attr, and append the constant axis as an input.
  node->set_op("ExpandDims");
  if (node->attr().count("axis") != 0) {
    node->mutable_attr()->erase("axis");
  }
  if (node->attr().count("N") != 0) {
    node->mutable_attr()->erase("N");
  }
  (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
  node->add_input(axis_node->name());
  node_map_->AddOutput(axis_node->name(), node->name());
  // Keep the axis as input 1, ahead of any trailing control inputs.
  if (node->input_size() > 2) {
    node->mutable_input()->SwapElements(1, node->input_size() - 1);
  }
  return true;
}
2519 
SimplifyCase(GraphDef * optimized_graph,NodeDef * node)2520 bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
2521   if (node->op() != "Case") return false;
2522   const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
2523   if (output_idx_node == nullptr ||
2524       !CheckAttrExists(*output_idx_node, "value").ok()) {
2525     return false;
2526   }
2527   Tensor output_idx_t;
2528   if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
2529     return false;
2530   int output_idx = output_idx_t.scalar<int>()();
2531   const auto& func_list = node->attr().at("branches").list();
2532   if (output_idx < 0 || output_idx >= func_list.func_size()) return false;
2533   NodeDef call_node = *node;
2534   call_node.set_op("PartitionedCall");
2535   call_node.clear_input();
2536   for (int i = 1; i < node->input_size(); ++i) {
2537     call_node.add_input(node->input(i));
2538   }
2539   auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
2540   *new_func = func_list.func(output_idx);
2541 
2542   // Move the output shape of the branch to _output_shapes if it is known.
2543   const auto& output_shape_list =
2544       (*node->mutable_attr())["output_shapes"].list();
2545   if (output_shape_list.shape_size() > output_idx) {
2546     TensorShapeProto* new_output_shape =
2547         (*call_node.mutable_attr())["_output_shapes"]
2548             .mutable_list()
2549             ->add_shape();
2550     *new_output_shape =
2551         std::move(node->attr().at("output_shapes").list().shape(output_idx));
2552   }
2553 
2554   call_node.mutable_attr()->erase("output_shapes");
2555   call_node.mutable_attr()->erase("branches");
2556 
2557   *node = std::move(call_node);
2558   return true;
2559 }
2560 
SimplifySelect(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2561 bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
2562                                      GraphDef* optimized_graph, NodeDef* node) {
2563   if (!IsSelect(*node)) return false;
2564   const std::vector<OpInfo::TensorProperties>& input_props =
2565       properties.GetInputProperties(node->name());
2566   if (input_props.size() < 3) return false;
2567   const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
2568   const bool is_all_true = IsOnes(*predicate_node);
2569   const bool is_all_false = IsZeros(*predicate_node);
2570   if (!is_all_true && !is_all_false) {
2571     return false;
2572   }
2573   const int live_input_idx = is_all_true ? 1 : 2;
2574   const int ignored_input_idx = is_all_true ? 2 : 1;
2575   const TensorShapeProto& predicate_shape = input_props[0].shape();
2576   const bool predicate_is_scalar =
2577       !predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
2578   if (ShapesSymbolicallyEqual(input_props[1], input_props[2]) &&
2579       (ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
2580        predicate_is_scalar)) {
2581     // Replace node with Identity if no broadcasting is involved.
2582     node->set_op("Identity");
2583     *node->mutable_input(0) =
2584         AddControlDependency(node->input(0), optimized_graph, node_map_.get());
2585     *node->mutable_input(ignored_input_idx) = AddControlDependency(
2586         node->input(ignored_input_idx), optimized_graph, node_map_.get());
2587     node->mutable_input()->SwapElements(0, live_input_idx);
2588   } else if (!ReplaceOperationWithBroadcastTo(live_input_idx, properties, node,
2589                                               optimized_graph)) {
2590     return false;
2591   }
2592   DedupControlInputs(node);
2593   return true;
2594 }
2595 
RemoveRedundantVariableUpdates(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)2596 void ConstantFolding::RemoveRedundantVariableUpdates(
2597     GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
2598   static const absl::flat_hash_set<string>* kVariableReadOps =
2599       new absl::flat_hash_set<string>{"AssignAddVariableOp",
2600                                       "AssignSubVariableOp",
2601                                       "AssignAdd",
2602                                       "AssignSub",
2603                                       "ScatterAdd",
2604                                       "ScatterSub",
2605                                       "ScatterMul",
2606                                       "ScatterDiv",
2607                                       "ScatterNdAdd",
2608                                       "ScatterNdSub",
2609                                       "ScatterNdMul",
2610                                       "ScatterNdDiv",
2611                                       "ResourceScatterAdd",
2612                                       "ResourceScatterSub",
2613                                       "ResourceScatterMul",
2614                                       "ResourceScatterDiv",
2615                                       "ResourceScatterNdAdd",
2616                                       "ResourceScatterNdSub",
2617                                       "ResourceScatterNdMul",
2618                                       "ResourceScatterNdDiv"};
2619   if (kVariableReadOps == nullptr ||
2620       kVariableReadOps->find(node->op()) == kVariableReadOps->end())
2621     return;
2622   const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
2623   const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
2624   if (delta_node == nullptr) return;
2625   const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
2626                              absl::StrContains(node->op(), "Sub");
2627   if ((is_add_or_sub && IsZeros(*delta_node)) ||
2628       (!is_add_or_sub && IsOnes(*delta_node))) {
2629     VLOG(1) << "Removing redundant variable update: " << node->DebugString();
2630     if (absl::StrContains(node->op(), "Variable") ||
2631         absl::StrContains(node->op(), "Resource")) {
2632       ReplaceOperationWithNoOp(node, properties, optimized_graph);
2633     } else {
2634       ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
2635                                    optimized_graph);
2636     }
2637   }
2638 }
2639 
// For an Enter node marked is_constant whose input is a real constant,
// clones the constant into the loop frame (anchored by a control dependency
// on the Enter) and repoints non-constant consumers of the Enter at the
// clone. This exposes the constant to further folding inside the loop body.
// Returns true if the graph was modified (also sets graph_modified_).
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
                                             NodeDef* node) {
  if (!IsEnter(*node) || node->input_size() == 0 ||
      node->attr().count("is_constant") == 0 ||
      !node->attr().at("is_constant").b()) {
    return false;
  }
  const string& node_name = node->name();
  const NodeDef* input = node_map_->GetNode(node->input(0));
  // Bail out if the input isn't a foldable constant or if a previous pass
  // already created the "_enter" clone.
  if (input == nullptr || !IsReallyConstant(*input) ||
      OptimizedNodeExists(*input, "_enter")) {
    return false;
  }
  // Find non-constant nodes that consume the output of *node.
  std::vector<NodeDef*> consumers;
  for (const NodeDef* fanout : node_map_->GetOutputs(node_name)) {
    if (!IsConstant(*fanout)) {
      for (int i = 0; i < fanout->input_size(); ++i) {
        if (fanout->input(i) == node_name) {
          consumers.push_back(const_cast<NodeDef*>(fanout));
          break;
        }
      }
    }
  }
  if (consumers.empty()) {
    return false;
  }
  graph_modified_ = true;
  // Clone the constant; the control dependency on the Enter keeps the clone
  // inside the correct loop frame.
  NodeDef* new_node = optimized_graph->add_node();
  *new_node = *input;
  new_node->set_name(OptimizedNodeName(*input, "_enter"));
  new_node->set_device(node->device());
  new_node->clear_input();
  new_node->add_input(AsControlDependency(node_name));
  node_map_->AddNode(new_node->name(), new_node);
  node_map_->AddOutput(node_name, new_node->name());
  // Redirect each consumer input that referenced the Enter to the clone.
  for (NodeDef* consumer : consumers) {
    for (int i = 0; i < consumer->input_size(); ++i) {
      if (NodeName(consumer->input(i)) == node_name) {
        node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
        consumer->set_input(i, new_node->name());
      }
    }
  }
  return true;
}
2687 
// Simplifies a Switch whose data input equals its predicate input: in that
// case output port 0 (false branch) fires iff the predicate is false, and
// port 1 (true branch) iff it is true. The outputs are therefore replaced by
// constant false/true nodes control-dependent on the corresponding switch
// port, letting downstream logic fold further. Returns true on modification.
bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
  if (node->op() == "Switch" && node->input(0) == node->input(1) &&
      !OptimizedNodeExists(*node, "_const_false") &&
      !OptimizedNodeExists(*node, "_const_true")) {
    bool already_optimized = true;
    // If the optimization was already applied, the switch would have exactly
    // one Identity node consuming each of its outputs, each without any
    // non-control outputs.
    const auto& fanouts = node_map_->GetOutputs(node->name());
    if (fanouts.size() == 2) {
      for (const NodeDef* fanout : fanouts) {
        if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
            HasRegularOutputs(*fanout, *node_map_)) {
          already_optimized = false;
          break;
        }
      }
    }
    Tensor false_t(DT_BOOL, TensorShape({}));
    Tensor true_t(DT_BOOL, TensorShape({}));
    // Make sure we don't proceed if this switch node was already optimized.
    if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
        SetTensorValue(DT_BOOL, false, &false_t).ok()) {
      // Copy the set of consumers of the switch as they will be manipulated
      // below.
      std::vector<NodeDef*> consumers =
          node_map_->GetOutputsOrderedByNodeName(node->name());
      // Create constant false & true nodes.
      NodeDef tmp_false_node;
      tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
      if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
                         &tmp_false_node)
               .ok()) {
        return false;
      }
      tmp_false_node.set_device(node->device());
      NodeDef tmp_true_node;
      tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
      if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
                         &tmp_true_node)
               .ok()) {
        return false;
      }
      tmp_true_node.set_device(node->device());

      // Add const nodes to graph.
      NodeDef* false_node = optimized_graph->add_node();
      false_node->Swap(&tmp_false_node);
      NodeDef* true_node = optimized_graph->add_node();
      true_node->Swap(&tmp_true_node);

      // Add controls from the switch ports to the constants, and connect the
      // constants to the original switch outputs. Port 0 ("name") is the
      // false branch; port 1 ("name:1") is the true branch.
      const string false_port = node->name();
      const string true_port = strings::StrCat(node->name(), ":1");
      const string false_ctrl_dep =
          AddControlDependency(false_port, optimized_graph, node_map_.get());
      false_node->add_input(false_ctrl_dep);
      const string true_ctrl_dep =
          AddControlDependency(true_port, optimized_graph, node_map_.get());
      true_node->add_input(true_ctrl_dep);

      node_map_->AddNode(false_node->name(), false_node);
      node_map_->AddNode(true_node->name(), true_node);
      node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
      node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());

      // Rewire every consumer of either switch port to the matching constant.
      for (NodeDef* consumer : consumers) {
        for (int i = 0; i < consumer->input_size(); ++i) {
          const string& input = consumer->input(i);
          if (input == false_port) {
            consumer->set_input(i, false_node->name());
            node_map_->UpdateInput(consumer->name(), false_port,
                                   false_node->name());
          } else if (input == true_port) {
            consumer->set_input(i, true_node->name());
            node_map_->UpdateInput(consumer->name(), true_port,
                                   true_node->name());
          }
        }
      }
      return true;
    }
  }
  return false;
}
2774 
IsReductionWithConstantIndices(const NodeDef & node,bool * indices_is_empty) const2775 bool ConstantFolding::IsReductionWithConstantIndices(
2776     const NodeDef& node, bool* indices_is_empty) const {
2777   // Ensure its an appropriate Reduce node.
2778   if (!IsReduction(node) || node.input_size() < 2) {
2779     return false;
2780   }
2781   // Ensure that the axes to reduce by are constant.
2782   NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
2783   if (!IsReallyConstant(*reductions_indices) ||
2784       !reductions_indices->attr().count("value")) {
2785     return false;
2786   }
2787   const TensorShapeProto& reduction_indices_shape =
2788       reductions_indices->attr().at("value").tensor().tensor_shape();
2789   *indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
2790   return true;
2791 }
2792 
IsReductionCandidateForSimplification(const NodeDef & node,const GraphProperties & properties,TensorShapeProto * input_tensor_shape,TensorShapeProto * output_tensor_shape,bool * is_single_element_op) const2793 bool ConstantFolding::IsReductionCandidateForSimplification(
2794     const NodeDef& node, const GraphProperties& properties,
2795     TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
2796     bool* is_single_element_op) const {
2797   // Get the properties of the input & output tensors and check if they both
2798   // contain a single element.
2799   if (!properties.HasInputProperties(node.name()) ||
2800       !properties.HasOutputProperties(node.name())) {
2801     return false;
2802   }
2803   const auto& input_props = properties.GetInputProperties(node.name())[0];
2804   const auto& output_props = properties.GetOutputProperties(node.name())[0];
2805   if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
2806       !output_props.has_shape() || output_props.shape().unknown_rank()) {
2807     return false;
2808   }
2809   *input_tensor_shape = input_props.shape();
2810   *output_tensor_shape = output_props.shape();
2811   for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
2812     if (input_tensor_shape->dim(i).size() < 0) {
2813       return false;
2814     }
2815   }
2816   for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
2817     if (output_tensor_shape->dim(i).size() < 0) {
2818       return false;
2819     }
2820   }
2821   const int input_num_elements =
2822       TensorShape(*input_tensor_shape).num_elements();
2823   const int output_num_elements =
2824       TensorShape(*output_tensor_shape).num_elements();
2825   *is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
2826 
2827   return true;
2828 }
2829 
IsReductionSimplifiableToIdentity(const NodeDef & node,const TensorShapeProto & input_shape,bool keep_dims,const TensorVector & reduction_indices_vector) const2830 bool ConstantFolding::IsReductionSimplifiableToIdentity(
2831     const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
2832     const TensorVector& reduction_indices_vector) const {
2833   int output_size = reduction_indices_vector[0]->NumElements();
2834   if (output_size == 0) {
2835     return true;
2836   }
2837 
2838   if (!keep_dims) {
2839     return false;
2840   }
2841   bool simplifiable = true;
2842   for (int i = 0; i < output_size; ++i) {
2843     int64_t dim;
2844     if (reduction_indices_vector[0]->dtype() == DT_INT32) {
2845       dim = reduction_indices_vector[0]->flat<int32>()(i);
2846     } else {
2847       dim = reduction_indices_vector[0]->flat<int64_t>()(i);
2848     }
2849     if (dim < 0) {
2850       dim += input_shape.dim_size();
2851     }
2852     if (dim < 0 || dim >= input_shape.dim_size() ||
2853         input_shape.dim(dim).size() != 1) {
2854       simplifiable = false;
2855       break;
2856     }
2857   }
2858   return simplifiable;
2859 }
2860 
ReplaceReductionWithIdentity(NodeDef * node) const2861 bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
2862   // Replace the reduction node with an identity node, that can be further
2863   // optimized by other passes.
2864   DataType output_type;
2865   if (node->attr().count("T") != 0) {
2866     output_type = node->attr().at("T").type();
2867   } else if (IsAny(*node) || IsAll(*node)) {
2868     output_type = DT_BOOL;
2869   } else {
2870     return false;
2871   }
2872   node->set_op("Identity");
2873   EraseRegularNodeAttributes(node);
2874   (*node->mutable_attr())["T"].set_type(output_type);
2875   *node->mutable_input(1) = AsControlDependency(node->input(1));
2876   return true;
2877 }
2878 
bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
                                        const GraphProperties& properties,
                                        NodeDef* node) {
  // Simplifies a reduction node in one of two ways:
  //  1. Reduction over an empty set of axes -> Identity.
  //  2. Single-element input and output, without keep_dims -> Reshape to the
  //     (known) output shape.
  //  3. keep_dims reduction where every reduced axis has size 1 -> Identity.
  // Returns true iff the node was rewritten.
  bool indices_is_empty = false;
  if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
    return false;
  }
  if (indices_is_empty) {
    return ReplaceReductionWithIdentity(node);
  }
  bool is_single_element_op = false;
  TensorShapeProto input_tensor_shape, output_tensor_shape;
  if (!IsReductionCandidateForSimplification(
          *node, properties, &input_tensor_shape, &output_tensor_shape,
          &is_single_element_op)) {
    return false;
  }

  // Get the reduction indices by constant-evaluating the indices node.
  string reduction_indices_input = node->input(1);
  NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
  TensorVector reduction_indices_vector;
  // EvaluateNode heap-allocates output tensors; release them on every exit
  // path from this function.
  auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
    for (const auto& out : reduction_indices_vector) {
      delete out.tensor;
    }
  });
  if (!EvaluateNode(*reduction_indices, TensorVector(),
                    &reduction_indices_vector)
           .ok() ||
      reduction_indices_vector.size() != 1) {
    return false;
  }

  // keep_dims defaults to false when the attribute is absent.
  bool keep_dims =
      node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
  // The Reshape rewrite needs the "T" attribute to carry over the dtype.
  bool simplifiable_to_reshape =
      is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
  bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
      *node, input_tensor_shape, keep_dims, reduction_indices_vector);

  if (simplifiable_to_reshape) {
    // Build a const node holding the output shape (all dims are 1, since the
    // output is a single element of known rank).
    const int new_num_dimensions = output_tensor_shape.dim_size();
    Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
    for (int i = 0; i < new_num_dimensions; i++) {
      tensor.flat<int>()(i) = 1;
    }
    TensorValue shape_value(&tensor);
    NodeDef* shape_node = optimized_graph->add_node();
    if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
                       shape_node)
             .ok()) {
      return false;
    }
    shape_node->set_device(node->device());
    node_map_->AddNode(shape_node->name(), shape_node);
    // Control dependency to ensure shape_node is in the correct frame.
    shape_node->add_input(AsControlDependency(reduction_indices_input));
    node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
    // Rewrite the reduction node in place as a Reshape, wiring its second
    // input to the new shape const and swapping the reduction attributes
    // (keep_dims/Tidx) for Reshape's Tshape.
    node->set_op("Reshape");
    node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
    node->set_input(1, shape_node->name());
    node->mutable_attr()->erase("keep_dims");
    node->mutable_attr()->erase("Tidx");
    AttrValue attr_type_indices;
    attr_type_indices.set_type(DT_INT32);
    (*node->mutable_attr())["Tshape"] = attr_type_indices;
    return true;
  } else if (simplifiable_to_identity) {
    return ReplaceReductionWithIdentity(node);
  }
  return false;
}
2954 
SimplifyReshape(const GraphProperties & properties,bool use_shape_info,NodeDef * node)2955 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
2956                                       bool use_shape_info, NodeDef* node) {
2957   if (!use_shape_info || node->attr().count("T") == 0 ||
2958       !IsSimplifiableReshape(*node, properties).ok()) {
2959     return false;
2960   }
2961   DataType output_type = node->attr().at("T").type();
2962   node->set_op("Identity");
2963   EraseRegularNodeAttributes(node);
2964   (*node->mutable_attr())["T"].set_type(output_type);
2965   *node->mutable_input(1) = AsControlDependency(node->input(1));
2966   return true;
2967 }
2968 
Status ConstantFolding::SimplifyArithmeticOperations(
    const GraphProperties& properties, bool use_shape_info,
    GraphDef* optimized_graph, NodeDef* node) {
  // Simplifies binary arithmetic with identity/absorbing constants:
  //   1*y -> y, 0+y -> y, 0-y -> Neg(y), 1/y -> Reciprocal(y),
  //   x*1 -> x, x/1 -> x, x+/-0 -> x, x OR true -> true,
  //   x*0 / 0*y / (aggressive only) 0/y -> zeros.
  // Always returns OkStatus unless a node input is missing or a required
  // attribute lookup fails; rewrites happen as a side effect on `node`.
  const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
  const bool is_matmul = IsAnyMatMul(*node);
  const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
  const bool is_sub = IsSub(*node);
  // FloorDiv is excluded: its rounding makes these rewrites invalid.
  const bool is_any_div = IsAnyDiv(*node) && !IsFloorDiv(*node);
  // Simplify arithmetic operations with ones or zeros.
  if (use_shape_info &&
      (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
      properties.HasInputProperties(node->name()) &&
      properties.HasOutputProperties(node->name())) {
    const NodeDef* x = node_map_->GetNode(node->input(0));
    const NodeDef* y = node_map_->GetNode(node->input(1));
    if (x == nullptr || y == nullptr) {
      return errors::InvalidArgument("Invalid inputs to node: ",
                                     node->DebugString());
    }
    const TensorShapeProto& output_shape =
        properties.GetOutputProperties(node->name())[0].shape();

    // Simplify element-wise multiplication by ones or addition/subtraction
    // of zeros. Shape matches below decide whether the surviving operand can
    // be forwarded directly or must be broadcast to the output shape.
    const TensorShapeProto& y_shape =
        properties.GetInputProperties(node->name())[1].shape();
    const TensorShapeProto& x_shape =
        properties.GetInputProperties(node->name())[0].shape();
    const bool y_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, y_shape);
    const bool x_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, x_shape);

    const bool x_is_zero = IsZeros(*x);
    // Skip the IsOnes check when x is known to be zeros.
    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
    if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
      // 1 * y = y or 0 + y = y.
      if (y_matches_output_shape) {
        ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
      } else if (x_matches_output_shape) {
        ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
                                              optimized_graph);
      }
      return OkStatus();
    }

    if (y_matches_output_shape && (is_sub && x_is_zero)) {
      // Replace 0 - y with Neg(y).
      ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
      return OkStatus();
    }

    // Replace 1 / y with Reciprocal op.
    if (y_matches_output_shape && is_any_div && x_is_one) {
      TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
      DataType type = node->attr().at("T").type();
      // Only valid for float/complex: integer 1/y is not Reciprocal.
      if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
        ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
        return OkStatus();
      }
    }

    const bool y_is_zero = IsZeros(*y);
    const bool y_is_one = y_is_zero ? false : IsOnes(*y);
    if (((is_mul || is_any_div) && y_is_one) ||
        ((is_add || is_sub) && y_is_zero)) {
      // x * 1 = x or x / 1 = x or x +/- 0 = x
      if (x_matches_output_shape) {
        ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
      } else if (y_matches_output_shape) {
        ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
                                              optimized_graph);
      }
      return OkStatus();
    }

    // x OR true = true OR y = true.
    const PartialTensorShape shp(output_shape);
    if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
          1, properties, output_shape, node, optimized_graph));
      return OkStatus();
    }

    // Simplify multiplication and matmul by zeros.
    // Also optimize zeros divided by a tensor, but only if we are in
    // aggressive mode, since we might get rid of divisions by zero.
    const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
    bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
    if ((x_is_zero || y_is_zero) &&
        (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
      if (shp.IsFullyDefined()) {
        bool is_quantized = IsQuantizedMatMul(*node);
        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
            0, properties, output_shape, node, optimized_graph));
        // Quantized matmuls also produce min/max outputs; recreate those as
        // constants alongside the zero result.
        if (is_quantized && graph_modified_) {
          TF_RETURN_IF_ERROR(
              AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
        }
        return OkStatus();
      }
      // Even if an input shape is only partially known, we may known that it
      // matches the output shape and thus forward or broadcast the
      // corresponding zero input.
      if ((is_mul || is_any_div) && x_is_zero) {
        if (x_matches_output_shape) {
          ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        } else if (y_matches_output_shape) {
          ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
                                                optimized_graph);
        }
        return OkStatus();
      } else if (is_mul && y_is_zero) {
        if (y_matches_output_shape) {
          ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
        } else if (x_matches_output_shape) {
          ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
                                                optimized_graph);
        }
        return OkStatus();
      }
    }
  }
  return OkStatus();
}
3094 
ReduceDivToReciprocalMul(GraphDef * optimized_graph,NodeDef * node)3095 bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
3096                                                NodeDef* node) {
3097   // Strength reduce floating point division by a constant Div(x, const) to
3098   // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
3099   // will be constant folded to Mul(x, 1.0/const).
3100   if (node->input_size() >= 2 &&
3101       (IsDiv(*node) || IsRealDiv(*node) || IsXdivy(*node))) {
3102     const string& const_input = node->input(1);
3103     const NodeDef* denom = node_map_->GetNode(const_input);
3104     CHECK(denom != nullptr);
3105     if (!IsReallyConstant(*denom)) {
3106       return false;
3107     }
3108     if (node->attr().count("T") == 0) {
3109       return false;
3110     }
3111     DataType type = node->attr().at("T").type();
3112     // Skip integer division.
3113     if (IsDiv(*node) &&
3114         !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
3115       return false;
3116     }
3117     // Insert new reciprocal op and change node from Div to Mul.
3118     NodeDef* reciprocal_node = optimized_graph->add_node();
3119     reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
3120     reciprocal_node->set_op("Reciprocal");
3121     reciprocal_node->set_device(node->device());
3122     reciprocal_node->add_input(const_input);
3123     (*reciprocal_node->mutable_attr())["T"].set_type(type);
3124 
3125     // Re-wire inputs and outputs.
3126     if (IsXdivy(*node)) {
3127       node->set_op("MulNoNan");
3128       node->set_input(1, node->input(0));
3129       node->set_input(0, reciprocal_node->name());
3130     } else {
3131       node->set_op("Mul");
3132       node->set_input(1, reciprocal_node->name());
3133     }
3134     node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
3135     node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
3136 
3137     return true;
3138   }
3139 
3140   return false;
3141 }
3142 
bool ConstantFolding::PrepareConstantPushDown(
    const NodeDef& parent, const GraphProperties& properties,
    bool must_have_properties, ConstantPushDownContext* ctx) const {
  // Validates that `parent` is a candidate for a constant push-down rewrite
  // and fills in `ctx` (children, leaves, const-ness flags, input
  // properties). Returns false on any disqualifying condition; in that case
  // `ctx` may be partially populated and must not be used.
  if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
    return false;
  }
  NodeDef* left_child = node_map_->GetNode(parent.input(0));
  NodeDef* right_child = node_map_->GetNode(parent.input(1));

  // Sanity check for missing children.
  if (left_child == nullptr || right_child == nullptr) {
    return false;
  }

  ctx->left_child_is_const = IsReallyConstant(*left_child);
  ctx->right_child_is_const = IsReallyConstant(*right_child);
  // If both children are constant, op_child/const_child both resolve to the
  // pair (right, left); callers bail out earlier for the both-const case via
  // regular constant folding.
  ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
  ctx->const_child = ctx->left_child_is_const ? left_child : right_child;

  // Nothing to do unless the parent has a constant child node.
  if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
    return false;
  }

  // Don't move nodes across devices.
  if (parent.device() != ctx->op_child->device() ||
      parent.device() != ctx->const_child->device()) {
    return false;
  }

  // Make sure that it is safe to change the value of the child node result:
  // it must not be preserved, and it must have no other data consumers.
  if (ctx->op_child->input_size() < 2 ||
      nodes_to_preserve_.find(ctx->op_child->name()) !=
          nodes_to_preserve_.end() ||
      NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
    return false;
  }

  // Don't apply reassociation to floating point types of low precision.
  // The danger of significant numerical changes is too high.
  if (!CheckAttrExists(parent, "T").ok()) return false;
  DataType dtype = parent.attr().at("T").type();
  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
    return false;
  }

  // Don't rewrite the tree if it might create cycles.
  // TODO(rmlarsen): Add back handling of control dependency from op to C.
  const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
  if (child_output.find(ctx->const_child) != child_output.end()) {
    return false;
  }

  // Get leaf nodes.
  // NOTE(review): GetNode can return nullptr; the dereferences below assume
  // the op_child's inputs are present in the node map — presumably
  // guaranteed by graph construction, but worth confirming.
  ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
  ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
  ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
  ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);

  if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
    // Child is already foldable, leave it alone.
    return false;
  }

  // Don't move nodes across devices.
  if (parent.device() != ctx->left_leaf->device() ||
      parent.device() != ctx->right_leaf->device()) {
    return false;
  }

  // Get shape and type information.
  ctx->parent_input_props = &properties.GetInputProperties(parent.name());
  ctx->op_child_input_props =
      &properties.GetInputProperties(ctx->op_child->name());
  // Some rewrites (e.g. BiasAdd push-down) need ranks/dtypes; those callers
  // pass must_have_properties=true to require at least two input properties
  // on both parent and child.
  if (must_have_properties && (ctx->parent_input_props == nullptr ||
                               ctx->parent_input_props->size() < 2 ||
                               ctx->op_child_input_props == nullptr ||
                               ctx->op_child_input_props->size() < 2)) {
    return false;
  }

  VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
          << parent.op() << "(" << left_child->op() << ", " << right_child->op()
          << ")";

  return true;
}
3230 
bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
                                              GraphDef* optimized_graph,
                                              NodeDef* node) {
  // This implements constant push-down for BiasAdd. In the following "CV" is a
  // constant vector (tensor of rank 1), "V" is a (possibly) non-constant
  // vector, "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
  // non-constant matrix, and "BA" is BiasAdd.
  // For a valid input graph, the following 4 rewrites are legal:
  //
  //  1)                  +                +
  //                     / \              / \
  //                    BA  CV    -- >   BA  V
  //                   / \              / \
  //                  M   V            M   CV
  //
  //  2)                  +                +
  //                     / \              / \
  //                    BA  CM    -- >   BA  M
  //                   / \              / \
  //                  M   V            CM  V
  //
  //  3)                  BA               BA
  //                     / \              / \
  //                    +  CV     -- >   +   V
  //                   / \              / \
  //                  M   V            M  CV
  //
  //  4)                  BA               BA      = parent
  //                     / \              / \
  //                    BA  CV    -- >   BA  V     = children
  //                   / \              / \
  //                  M   V            M  CV       = leaves
  //
  // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
  //
  // The rewrite swaps the parent's constant input with a same-rank,
  // same-dtype non-constant leaf of the child, moving the constant one level
  // down so it can pair with (and later fold into) other constants.
  // Returns true iff the graph was mutated.

  const bool parent_is_bias_add = IsBiasAdd(*node);
  if (!parent_is_bias_add && !IsAdd(*node)) return false;
  ConstantPushDownContext ctx;
  if (!PrepareConstantPushDown(*node, *properties,
                               /*must_have_properties=*/true, &ctx)) {
    return false;
  }
  // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
  // >= 2 and the leaves must be vectors, we cannot swap them.
  if (ctx.left_child_is_const && parent_is_bias_add) return false;
  const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
  if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;

  // Get properties to validate rank and dtype constraints.
  if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
      (*ctx.parent_input_props)[0].shape().unknown_rank() ||
      (*ctx.parent_input_props)[1].shape().unknown_rank() ||
      (*ctx.op_child_input_props)[0].shape().unknown_rank() ||
      (*ctx.op_child_input_props)[1].shape().unknown_rank()) {
    return false;
  }

  // Now get the ranks and types of the 3 leaf nodes.
  const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
  const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
  // At least one leaf must be a vector.
  if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
  const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
  const int matrix_idx = 1 - vector_idx;

  const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
  const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
  if (vector_rank != 1) return false;  // this should never happen.
  const DataType vector_type = vector_prop.dtype();

  const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
  const int matrix_rank = matrix_prop.shape().dim_size();
  const DataType matrix_type = matrix_prop.dtype();

  // Locate the parent's constant input and its rank/dtype.
  const int const_idx = ctx.left_child_is_const ? 0 : 1;
  const auto& const_prop = (*ctx.parent_input_props)[const_idx];
  const int const_rank = const_prop.shape().dim_size();
  const DataType const_type = const_prop.dtype();

  // Index (in the child) of the leaf to swap with the parent's constant;
  // -1 means no legal swap was found.
  int input_to_swap = -1;

  if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
      const_type == matrix_type) {
    // Case 2:
    input_to_swap = matrix_idx;
  } else if (const_rank == 1 && const_type == vector_type) {
    // Case 1, 3, and, 4:
    input_to_swap = vector_idx;
  }
  if (input_to_swap == -1) return false;
  const NodeDef* leaf_to_swap =
      node_map_->GetNode(ctx.op_child->input(input_to_swap));
  // Swapping two constants would be pointless (and is handled by folding).
  if (IsConstant(*leaf_to_swap)) return false;

  // Update the node map to reflect the swap before mutating the inputs.
  node_map_->UpdateInput(node->name(), node->input(const_idx),
                         ctx.op_child->input(input_to_swap));
  node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
  if (ctx.op_child->input(input_to_swap) !=
      ctx.op_child->input(1 - input_to_swap)) {
    node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
                            ctx.op_child->name());
  }
  std::swap(*node->mutable_input(const_idx),
            *ctx.op_child->mutable_input(input_to_swap));
  // Cached shapes for both mutated nodes are now stale.
  properties->ClearInputProperties(node->name());
  properties->ClearInputProperties(ctx.op_child->name());

  return true;
}
3340 
bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
                                       GraphDef* optimized_graph,
                                       NodeDef* node) {
  // Consider the transformation
  //
  //                      +                +       = parent
  //                     / \              / \
  //                    C   +    -- >    X   +     = children
  //                       / \              / \
  //                      X   Y            C   Y   = leaves
  //
  // where C is constant, X is non-constant, Y may be constant or non-constant,
  // and '+' denotes an associative and commutative operator like addition or
  // multiplication. This optimization pushes constants down in the tree to
  // canonicalize it. Moreover, in cases where the child node has a second
  // constant input Y we will create a leaf node that can be folded, e.g.
  //
  //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
  //
  // We also handle the non-commutative cases of subtraction and division
  // by rotating the tree locally, e.g.
  //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
  //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
  //
  // Returns true iff the graph was mutated.

  // Get parent op type.
  const bool is_add = IsAdd(*node);
  const bool is_mul = IsMul(*node);
  const bool is_sub = IsSub(*node);
  const bool is_div = IsDiv(*node);
  if (!(is_add || is_sub || is_mul || is_div)) return false;
  const bool is_symmetric = is_add || is_mul;

  ConstantPushDownContext ctx;
  if (!PrepareConstantPushDown(*node, *properties,
                               /*must_have_properties=*/false, &ctx)) {
    return false;
  }

  // Get child op type. The parent/child pair must be from the same family:
  // additive ({Add,Sub}) or multiplicative ({Mul,Div}).
  const bool is_child_add = IsAdd(*ctx.op_child);
  const bool is_child_mul = IsMul(*ctx.op_child);
  const bool is_child_sub = IsSub(*ctx.op_child);
  const bool is_child_div = IsDiv(*ctx.op_child);
  const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
  const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
  if (!is_add_sub && !is_mul_div) {
    return false;
  }
  const bool is_child_symmetric = is_child_add || is_child_mul;

  // If a non-commutative op (Sub/Div) is involved, the rotation below is only
  // valid for float/complex types (e.g. integer Div truncates).
  if (!CheckAttrExists(*node, "T").ok()) return false;
  DataType dtype = node->attr().at("T").type();
  if (!(is_symmetric && is_child_symmetric) &&
      !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
    return false;
  }

  const NodeDef* y_node =
      ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
  if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
      !ctx.op_child_input_props->empty()) {
    // If we know the shapes of the nodes being swapped, make sure we don't push
    // down a larger node and create more work by broadcasting earlier in the
    // expressions tree.
    const PartialTensorShape c_shape(
        (*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
    const PartialTensorShape x_shape(
        (*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());

    if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
        c_shape.num_elements() > x_shape.num_elements()) {
      return false;
    } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
               c_shape.dims() > 0) {
      // With partially known shapes, compare dimension by dimension: any
      // dimension of C exceeding the corresponding known dimension of X
      // means the swap would enlarge intermediate results.
      for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
        if (x_shape.dim_size(idx) >= 0 &&
            c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
          return false;
        }
      }
    }
  }

  // Get the node names corresponding to X, Y, and C.
  const string input_x =
      ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
  const string input_y = input_x == ctx.op_child->input(0)
                             ? ctx.op_child->input(1)
                             : ctx.op_child->input(0);
  const string input_c =
      ctx.left_child_is_const ? node->input(0) : node->input(1);
  const string input_op =
      ctx.left_child_is_const ? node->input(1) : node->input(0);
  VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;

  // Now we have identified the nodes to swap, update the nodemap accordingly.
  node_map_->UpdateInput(node->name(), input_c, input_x);
  node_map_->AddOutput(input_c, ctx.op_child->name());
  if (input_x != input_y) {
    node_map_->RemoveOutput(input_x, ctx.op_child->name());
  }
  // Cached shapes for both mutated nodes are now stale.
  properties->ClearInputProperties(node->name());
  properties->ClearInputProperties(ctx.op_child->name());

  if (is_symmetric && is_child_symmetric) {
    // Easy case (only commutative ops). We always write this as one of
    //   +
    //  / \
    // X   +
    //    / \
    //   C   Y
    node->set_input(0, input_x);
    node->set_input(1, input_op);
    ctx.op_child->set_input(0, input_c);
    ctx.op_child->set_input(1, input_y);
  } else {
    // More complicated case: When there are non-commutative operations like
    // subtractions or divisions involved, we may have to rotate the tree
    // and/or change op types. There are 6 non-trivial cases depending on
    // the effective generalized "sign" of each of the three terms C, Y, and X.
    // Here are the final trees we want to generate for those 6 cases:
    //
    // (CYX signs):   ++-      +--      -+-    --+     +-+      -++
    //
    //                 -        -        -      -       +        +
    //                / \      / \      / \    / \     / \      / \
    //               +   X    -   X    -   X  X   +   X   -    X   -
    //              / \      / \      / \        / \     / \      / \
    //             C   Y    C   Y    Y   C      Y   C   C   Y    Y   C
    //
    // (For Mul/Div the same table applies with "sign" meaning
    // numerator/denominator position instead of +/-.)

    // First, let's determine the effective sign of each term in the original
    // expression. A term is "negated" if it sits in the right operand of a
    // non-commutative op (Sub/Div), with double negation cancelling out.
    auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
      bool leaf_negated = !is_child_symmetric && is_right_leaf;
      bool child_negated = !is_symmetric && (ctx.left_child_is_const);
      return leaf_negated != child_negated;
    };
    const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
    const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
    bool neg_c = !is_symmetric && !ctx.left_child_is_const;
    bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
    bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
    // Rewrite the parent node per the sign table above.
    node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
    node->set_input(0, neg_x ? input_op : input_x);
    node->set_input(1, neg_x ? input_x : input_op);
    // Rewrite the child node.
    ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
    ctx.op_child->set_input(0, neg_c ? input_y : input_c);
    ctx.op_child->set_input(1, neg_c ? input_c : input_y);
  }
  return true;
}
3495 
MulConvPushDown(GraphDef * optimized_graph,NodeDef * node,const GraphProperties & properties)3496 bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
3497                                       const GraphProperties& properties) {
3498   // Push down multiplication on ConvND.
3499   //                       *                  ConvND
3500   //                     /   \                /    \
3501   //                 ConvND  C2    -- >      X      *
3502   //                  / \                          / \
3503   //                 X  C1                       C1  C2
3504   //
3505   // where C1 and C2 are constants and X is non-constant.
3506   //
3507   // TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
3508 
3509   if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
3510 
3511   NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
3512   NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
3513   if (mul_left_child == nullptr || mul_right_child == nullptr) {
3514     return false;
3515   }
3516   // One child must be constant, and the second must be Conv op.
3517   const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
3518   const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
3519   if (!left_child_is_constant && !right_child_is_constant) {
3520     return false;
3521   }
3522   NodeDef* conv_node =
3523       left_child_is_constant ? mul_right_child : mul_left_child;
3524   if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
3525     return false;
3526   }
3527   if (node->device() != mul_left_child->device() ||
3528       node->device() != mul_right_child->device()) {
3529     return false;
3530   }
3531 
3532   // Make sure that it is safe to change the value of the convolution
3533   // output.
3534   if (conv_node->input_size() < 2 ||
3535       NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
3536       nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
3537     return false;
3538   }
3539 
3540   // Identify the nodes to swap.
3541   NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
3542   NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
3543   const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
3544   const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
3545   if (!conv_left_is_constant && !conv_right_is_constant) {
3546     // At least one of the convolution inputs should be constant.
3547     return false;
3548   }
3549   if (conv_left_is_constant && conv_right_is_constant) {
3550     // Leverage regular constant folding to handle this.
3551     return false;
3552   }
3553   const auto& mul_props = properties.GetOutputProperties(node->name());
3554   const auto& conv_props = properties.GetOutputProperties(conv_node->name());
3555   if (mul_props.empty() || conv_props.empty()) {
3556     return false;
3557   }
3558   const auto& mul_shape = mul_props[0].shape();
3559   const auto& conv_shape = conv_props[0].shape();
3560   if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
3561     return false;
3562   }
3563 
3564   const auto& input_props = properties.GetInputProperties(conv_node->name());
3565   if (input_props.size() < 2) {
3566     return false;
3567   }
3568   const auto& filter_shape = input_props[1].shape();
3569 
3570   NodeDef* const_node =
3571       left_child_is_constant ? mul_left_child : mul_right_child;
3572   const auto& const_props = properties.GetOutputProperties(const_node->name());
3573   if (const_props.empty()) {
3574     return false;
3575   }
3576   const auto& const_shape = const_props[0].shape();
3577   if (!IsValidConstShapeForMulConvPushDown(
3578           conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
3579     return false;
3580   }
3581 
3582   string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
3583   if (node_map_->NodeExists(mul_new_name)) {
3584     return false;
3585   }
3586   // Make sure we don't introduce loops in the graph by removing control
3587   // dependencies from the conv2d node to c2.
3588   string conv_const_input =
3589       conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
3590   if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
3591                               node_map_.get())) {
3592     // Add a control dep from c1 to c2 to ensure c2 is in the right frame
3593     MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
3594                          node_map_.get());
3595   }
3596 
3597   conv_node->set_name(node->name());
3598   node->set_name(mul_new_name);
3599   if (conv_left_is_constant) {
3600     node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
3601     conv_node->set_input(0, mul_new_name);
3602   } else {
3603     node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
3604     conv_node->set_input(1, mul_new_name);
3605   }
3606   NodeDef* conv_const_node =
3607       conv_left_is_constant ? conv_left_child : conv_right_child;
3608   if (left_child_is_constant) {
3609     node->set_input(1, conv_const_node->name());
3610   } else {
3611     node->set_input(0, conv_const_node->name());
3612   }
3613   node_map_->AddNode(mul_new_name, node);
3614 
3615   return true;
3616 }
3617 
PartialConstPropThroughIdentityN(NodeDef * node)3618 bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
3619   // Partial constant propagation through IdentityN.
3620   if (!(IsIdentityN(*node) || IsIdentityNSingleInput(*node)) ||
3621       !HasRegularInputs(*node))
3622     return false;
3623 
3624   std::vector<int> inputs_to_forward;
3625   for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
3626     const string& input = node->input(input_idx);
3627     if (IsControlInput(input)) {
3628       return false;
3629     }
3630     const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3631     if (input_node == nullptr) {
3632       LOG(ERROR) << "Bad input: " << input;
3633       return false;
3634     }
3635     // Forward constant inputs to outputs and add a control dependency on
3636     // the IdentityN node.
3637     if (IsReallyConstant(*input_node)) {
3638       inputs_to_forward.push_back(input_idx);
3639     }
3640   }
3641   return ForwardInputs(node, inputs_to_forward);
3642 }
3643 
bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
                                                 GraphProperties* properties,
                                                 NodeDef* node) {
  // Partial constant folding for associative operators:
  // Split AddN/AccumulateNV2 to enable partial
  // folding of ops when more than one but not all inputs are constant.
  // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
  // addition is commutative.
  //
  // Returns true if the graph was modified.
  if (!IsAggregate(*node) || !IsCommutative(*node)) return false;

  const int num_non_control_inputs = NumNonControlInputs(*node);
  // With two or fewer data inputs there is nothing worth splitting.
  if (num_non_control_inputs <= 2) return false;
  const int num_control_inputs = node->input_size() - num_non_control_inputs;
  // Partition input positions into constant and non-constant ones; control
  // inputs are deliberately kept with the non-constant group so they stay on
  // the original node.
  std::vector<int> const_inputs;
  std::vector<int> nonconst_inputs;
  for (int i = 0; i < node->input_size(); ++i) {
    const string& input = node->input(i);
    const NodeDef* input_node = node_map_->GetNode(NodeName(input));
    if (input_node == nullptr) return false;
    if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
      const_inputs.push_back(i);
    } else {
      // Non-const and control inputs.
      nonconst_inputs.push_back(i);
    }
  }
  // Promote AccumulateNV2 with all constant inputs to AddN, since it is
  // a fake node that cannot be constant folded by itself.
  int const_inputs_size = const_inputs.size();
  if (const_inputs_size == num_non_control_inputs &&
      node->op() == "AccumulateNV2") {
    node->set_op("AddN");
    node->mutable_attr()->erase("shape");
    return true;
  }
  // Encode the constant count in the split node's name so a repeat pass over
  // the same node is detected by the NodeExists check below.
  const string new_node_name = OptimizedNodeName(
      *node, strings::StrCat("_partial_split_", const_inputs_size));
  if (const_inputs_size > 1 && const_inputs_size < num_non_control_inputs &&
      !node_map_->NodeExists(new_node_name)) {
    NodeDef* added_node = optimized_graph->add_node();
    *added_node = *node;
    // Always use AddN for the constant node, since AccumulateNV2 is a fake
    // node that cannot be constant folded, since it does not have a kernel.
    added_node->set_op("AddN");
    added_node->mutable_attr()->erase("shape");
    added_node->set_name(new_node_name);
    node_map_->AddNode(added_node->name(), added_node);
    added_node->clear_input();
    // Move every constant input onto the new AddN node.
    for (int i : const_inputs) {
      added_node->add_input(node->input(i));
      node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
                              added_node->name());
    }

    // Overwrite the first const input with the added node.
    node->set_input(const_inputs[0], added_node->name());
    node_map_->AddOutput(added_node->name(), node->name());
    // That slot now holds a non-constant value, so track it with the
    // non-constant inputs for the compaction below.
    nonconst_inputs.push_back(const_inputs[0]);
    // Compact the remaining inputs to the original node.
    std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
    int idx = 0;
    for (int i : nonconst_inputs) {
      if (idx != i) {
        node->set_input(idx, node->input(i));
      }
      ++idx;
    }
    // Drop the now-stale trailing slots: all original const inputs except the
    // one replaced by the AddN were compacted away.
    node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
                                          const_inputs.size() - 1);
    // Keep the "N" attrs (number of summands) of both nodes consistent with
    // their new input lists, and invalidate cached input shapes.
    (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
    properties->ClearInputProperties(node->name());
    (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
    return true;
  }
  return false;
}
3720 
bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
                                                GraphProperties* properties,
                                                NodeDef* node) {
  // Partial constant folding for Concat which is not commutative, so
  // we have to preserve order and can only push consecutive runs of constant
  // inputs into sub-nodes.
  //
  // Skip nodes produced by a previous partial split (their name contains
  // "_partial_split_") to avoid splitting repeatedly.
  if (!IsConcat(*node) ||
      node->name().rfind("_partial_split_") != string::npos) {
    return false;
  }
  const int num_non_control_inputs = NumNonControlInputs(*node);
  // Need at least two data inputs plus the axis before a split can pay off.
  if (num_non_control_inputs <= 3) return false;
  // Locate the data-input range [begin, end) and the axis argument, whose
  // position differs between Concat (axis first) and ConcatV2 (axis last).
  int axis_arg = -1;
  int begin = 0;
  int end = num_non_control_inputs;
  if (node->op() == "Concat") {
    begin = 1;
    axis_arg = 0;
  } else if (node->op() == "ConcatV2") {
    end = num_non_control_inputs - 1;
    axis_arg = num_non_control_inputs - 1;
  } else {
    return false;
  }

  // We search for consecutive runs of constant inputs in the range
  // [begin:end[ and push then down into child nodes.
  // NOTE(review): the GetNode results below are dereferenced unchecked; this
  // assumes every data input resolves in node_map_ — confirm for malformed
  // graphs.
  std::vector<std::pair<int, int>> constant_input_runs;
  int first = begin;
  int last = begin;
  while (last < end) {
    // Advance `first` to the next constant input.
    while (first < end && !IsReallyConstant(*node_map_->GetNode(
                              NodeName(node->input(first))))) {
      ++first;
    }
    // Invariant: node[first] is constant || first >= end.
    last = first + 1;
    // Extend the run of constants as far as it goes.
    while (last < end &&
           IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
      ++last;
    }
    // Invariant: node[last] is not constant || last >= end
    // Discard intervals shorter than 2 elements.
    if (first < end && (last - first) > 1) {
      constant_input_runs.emplace_back(first, last);
    }
    first = last;
  }

  // Skip if all inputs are constant, and let constant folding take over.
  if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
                                      constant_input_runs[0].first == begin &&
                                      constant_input_runs[0].second == end)) {
    return false;
  }
  std::set<int> inputs_to_delete;
  for (auto interval : constant_input_runs) {
    // Push the constant inputs in the interval to a child node than can be
    // constant folded.
    string new_node_name = OptimizedNodeName(*node, "_partial_split");
    do {
      new_node_name += strings::StrCat("_", interval.first);
    } while (node_map_->NodeExists(new_node_name));

    // The child is a ConcatV2 over the run's constants plus the original
    // axis argument.
    NodeDef* added_node = optimized_graph->add_node();
    *added_node = *node;
    added_node->set_op("ConcatV2");
    added_node->set_name(new_node_name);
    node_map_->AddNode(added_node->name(), added_node);
    added_node->clear_input();
    for (int i = interval.first; i < interval.second; ++i) {
      added_node->add_input(node->input(i));
      node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
      if (i != interval.first) {
        inputs_to_delete.insert(i);
      }
    }
    added_node->add_input(node->input(axis_arg));
    (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
    node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());

    // Overwrite the first constant input with the result of the added
    // child node.
    node->set_input(interval.first, added_node->name());
  }
  if (!inputs_to_delete.empty()) {
    // Fix up the inputs to the original node.
    protobuf::RepeatedPtrField<string> tmp;
    tmp.Swap(node->mutable_input());
    for (int i = 0; i < tmp.size(); ++i) {
      if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
        node->add_input(tmp.Get(i));
      }
    }
    // NOTE(review): N is set to input_size() - 1 (everything minus the axis),
    // which assumes no control inputs are present at this point — confirm.
    (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
    properties->ClearInputProperties(node->name());
  }
  return true;
}
3820 
GetConcatAxis(const NodeDef & node,int * axis)3821 bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
3822   if (node.op() != "ConcatV2") {
3823     return false;
3824   }
3825   int axis_idx = node.input_size() - 1;
3826   while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
3827     --axis_idx;
3828   }
3829   if (axis_idx <= 0) {
3830     return false;
3831   }
3832   Tensor axis_tensor;
3833   if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
3834     return false;
3835   }
3836   *axis = axis_tensor.dtype() == DT_INT64
3837               ? static_cast<int>(axis_tensor.scalar<int64_t>()())
3838               : axis_tensor.scalar<int32>()();
3839   return true;
3840 }
3841 
MergeConcat(bool use_shape_info,GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3842 bool ConstantFolding::MergeConcat(bool use_shape_info,
3843                                   GraphProperties* properties,
3844                                   GraphDef* optimized_graph, NodeDef* node) {
3845   // We only optimize for ConcatV2.
3846   int axis;
3847   if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
3848       nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
3849       node_map_->GetOutputs(node->name()).size() != 1) {
3850     return false;
3851   }
3852 
3853   // If all inputs are constant, don't merge and let folding take case of it.
3854   const int num_regular_inputs = NumNonControlInputs(*node);
3855   bool all_inputs_are_const = true;
3856   for (int i = 0; i < num_regular_inputs - 1; ++i) {
3857     const NodeDef* input_node = node_map_->GetNode(node->input(i));
3858     if (!IsReallyConstant(*input_node)) {
3859       all_inputs_are_const = false;
3860       break;
3861     }
3862   }
3863   if (all_inputs_are_const) return false;
3864 
3865   NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
3866   int parent_axis;
3867   if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
3868     return false;
3869   }
3870 
3871   // Make a pass over the parent inputs to see if any of them have explicit
3872   // device() fields set, and if different inputs are on different tasks.  If
3873   // so, this concat of concats may have been carefully constructed to be a
3874   // two-stage concat, and we don't want to undo that here.
3875   string task, device;
3876   absl::flat_hash_set<string> unique_input_tasks;
3877   const int n_parent_inputs = NumNonControlInputs(*parent);
3878   // Iterate over the real inputs to concatenate [0..n_parent_inputs - 1).  The
3879   // input at n_parent_inputs - 1 is the concat axis argument for a ConcatV2
3880   // node, which we don't want to consider here.
3881   for (int i = 0; i < n_parent_inputs - 1; ++i) {
3882     const NodeDef* input_node = node_map_->GetNode(parent->input(i));
3883     if (!input_node->device().empty() &&
3884         tensorflow::DeviceNameUtils::SplitDeviceName(input_node->device(),
3885                                                      &task, &device)) {
3886       unique_input_tasks.insert(task);
3887       if (unique_input_tasks.size() >= 2) {
3888         // More than one input task represented in the device specifications
3889         // of the parent's input nodes.  Don't mess with this.
3890         return false;
3891       }
3892     }
3893   }
3894 
3895   protobuf::RepeatedPtrField<string> parent_inputs;
3896   parent_inputs.Swap(parent->mutable_input());
3897   // TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
3898   // collapse it into the parent multiple times? Probably not.
3899   for (const auto& input : parent_inputs) {
3900     if (IsSameInput(input, node->name())) {
3901       for (int j = 0; j < num_regular_inputs - 1; ++j) {
3902         // Add tensor inputs to first child concat tensors (except the final
3903         // axis input) to the parent's inputs.
3904         parent->add_input(node->input(j));
3905         node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
3906       }
3907     } else {
3908       parent->add_input(input);
3909     }
3910   }
3911   // Forward Add control inputs
3912   const int num_inputs = node->input_size();
3913   for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
3914     parent->add_input(node->input(i));
3915     node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
3916     node->mutable_input()->RemoveLast();
3917   }
3918   (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
3919   DedupControlInputs(parent);
3920   ReplaceOperationWithNoOp(node, properties, optimized_graph);
3921 
3922   return true;
3923 }
3924 
Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
    NodeDef* node, GraphDef* optimized_graph) {
  // Materializes the min (output index 1) and max (output index 2) outputs of
  // a quantized matmul as scalar float Const nodes whose values are the
  // representable range of the node's "dtype", and rewires consumers of
  // those outputs to the new constants.
  auto add_quantized_out = [this, node, optimized_graph](
                               const string& out_const_name, int index) {
    NodeDef* out_node = optimized_graph->add_node();
    graph_modified_ = true;
    Tensor value(DT_FLOAT, TensorShape({}));
    // Index 1 is the min output, index 2 the max output.
    const bool is_min = index == 1;
    const DataType type_attr = node->attr().at("dtype").type();

    value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
                                    : QuantizedTypeMaxAsFloat(type_attr);
    TF_RETURN_IF_ERROR(
        CreateNodeDef(out_const_name, TensorValue(&value), out_node));
    node_map_->AddNode(out_const_name, out_node);
    out_node->set_device(node->device());
    // Copy all inputs from node.
    out_node->mutable_input()->CopyFrom(node->input());
    for (const string& input : out_node->input()) {
      node_map_->AddOutput(NodeName(input), out_const_name);
    }

    // Update output nodes consuming node:index to new const node.
    string old_input = absl::StrCat(node->name(), ":", index);
    int old_node_count = 0;
    // We make a copy since the set might change.
    auto outputs = node_map_->GetOutputs(node->name());
    for (const auto& output : outputs) {
      for (int i = 0; i < output->input_size(); ++i) {
        if (output->input(i) == old_input) {
          output->set_input(i, out_const_name);
          node_map_->AddOutput(out_const_name, output->name());
        } else if (NodeName(output->input(i)) == node->name()) {
          // The consumer still references some other output of `node`.
          ++old_node_count;
        }
      }
      // NOTE(review): old_node_count accumulates across ALL consumers rather
      // than being reset per consumer, so once any earlier consumer still
      // references `node`, no later consumer is ever removed from the output
      // map — confirm whether a per-output counter was intended.
      if (old_node_count == 0) {
        node_map_->RemoveOutput(node->name(), output->name());
      }
    }

    return OkStatus();
  };
  const string min_out_const_name =
      OptimizedNodeName(*node, "-quantized_matmul_min_out");
  const string max_out_const_name =
      OptimizedNodeName(*node, "-quantized_matmul_max_out");
  // Only proceed if neither const exists yet; a partial pre-existing pair
  // would indicate a name conflict from an earlier run.
  if (node_map_->GetNode(min_out_const_name) == nullptr &&
      node_map_->GetNode(max_out_const_name) == nullptr) {
    TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
    TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
  } else {
    return errors::Internal(absl::Substitute(
        "Can't create Const for QuantizedMatMul min_out/max_out of "
        "node '$0' because of node name conflict",
        node->name()));
  }
  return OkStatus();
}
3984 
RunOptimizationPass(Cluster * cluster,GrapplerItem * item,GraphProperties * properties,GraphDef * optimized_graph)3985 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
3986                                             GrapplerItem* item,
3987                                             GraphProperties* properties,
3988                                             GraphDef* optimized_graph) {
3989   optimized_graph->Clear();
3990   graph_ = &item->graph;
3991   node_map_.reset(new NodeMap(graph_));
3992   nodes_allowlist_.clear();
3993   // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
3994   // has a single fanout, it would be rewritten as a constant with the same
3995   // node name, and therefore users are still able to fetch it. This is not
3996   // the case if the node has multiple fanouts, and constant folding would
3997   // replace the node with multiple constants (each for one fanout) with
3998   // new names, and as a result users would not be able to fetch the node any
3999   // more with the original node name.
4000   for (const auto& fetch : item->fetch) {
4001     const NodeDef* fetch_node = node_map_->GetNode(fetch);
4002     if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
4003       nodes_allowlist_.insert(fetch_node->name());
4004     }
4005   }
4006 
4007   absl::flat_hash_set<string> nodes_to_not_simplify;
4008   if (properties->has_properties()) {
4009     TF_RETURN_IF_ERROR(MaterializeShapes(*properties));
4010     TF_RETURN_IF_ERROR(MaterializeConstants(*properties));
4011     TF_RETURN_IF_ERROR(
4012         FoldGraph(*properties, optimized_graph, &nodes_to_not_simplify));
4013   } else {
4014     *optimized_graph = *graph_;
4015   }
4016   node_map_.reset(new NodeMap(optimized_graph));
4017 
4018   TF_RETURN_IF_ERROR(
4019       SimplifyGraph(optimized_graph, properties, &nodes_to_not_simplify));
4020 
4021   return OkStatus();
4022 }
4023 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)4024 Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
4025                                  GraphDef* optimized_graph) {
4026   // TensorFlow flushes denormals to zero and rounds to nearest, so we do
4027   // the same here.
4028   port::ScopedFlushDenormal flush;
4029   port::ScopedSetRound round(FE_TONEAREST);
4030   nodes_to_preserve_ = item.NodesToPreserve();
4031   for (const auto& feed : item.feed) {
4032     feed_nodes_.insert(NodeName(feed.first));
4033   }
4034 
4035   if (cpu_device_ == nullptr) {
4036     owned_device_.reset(new DeviceSimple());
4037     cpu_device_ = owned_device_.get();
4038   }
4039 
4040   graph_contains_assign_or_inplace_op_ = false;
4041   for (const NodeDef& node : item.graph.node()) {
4042     if (ModifiesInputsInPlace(node) || HasRefInput(node)) {
4043       graph_contains_assign_or_inplace_op_ = true;
4044       break;
4045     }
4046   }
4047 
4048   has_fetch_ = !item.fetch.empty();
4049   GrapplerItem item_to_optimize = item;
4050   GraphProperties properties(item_to_optimize);
4051   // It's possible to feed a placeholder with a tensor of any shape: make sure
4052   // that the shape inference deals with this conservatively unless we're in
4053   // aggressive mode.
4054   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
4055   if (!properties
4056            .InferStatically(assume_valid_feeds,
4057                             /*aggressive_shape_inference=*/false,
4058                             /*include_input_tensor_values=*/false,
4059                             /*include_output_tensor_values=*/true)
4060            .ok()) {
4061     properties.Clear();
4062   }
4063 
4064   *optimized_graph = GraphDef();
4065   item_to_optimize.graph.Swap(optimized_graph);
4066   int64_t node_count;
4067 
4068   do {
4069     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
4070     graph_modified_ = false;
4071     item_to_optimize.graph.Swap(optimized_graph);
4072     node_count = item_to_optimize.graph.node_size();
4073     TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, &item_to_optimize,
4074                                            &properties, optimized_graph));
4075   } while (graph_modified_ || optimized_graph->node_size() != node_count);
4076   *optimized_graph->mutable_library() = item.graph.library();
4077   *optimized_graph->mutable_versions() = item.graph.versions();
4078 
4079   return OkStatus();
4080 }
4081 
4082 }  // namespace grappler
4083 }  // namespace tensorflow
4084